import { MeshLambertMaterial, MeshDepthMaterial, RGBADepthPacking, Color, DoubleSide } from 'three'
import Ids from '../Cloak/Ids'
import defaultTheme from '../defaultTheme'

/**
 * Creates the material used to render selected nodes.
 *
 * The returned `MeshLambertMaterial` has its shading replaced in
 * `onBeforeCompile`. The vertex shader computes a colour from `u_color` with
 * ambient + Lambert diffuse + Phong-style specular terms against a fixed light
 * direction. The fragment shader writes that interpolated colour directly to
 * `gl_FragColor`. A companion `MeshDepthMaterial` with a minimal depth shader
 * is attached as `mat.depthMaterial`. It is presumably assigned to a mesh's
 * `customDepthMaterial` by the caller — TODO confirm.
 *
 * Extra API attached to the returned material:
 * - `setTime(time)`: remembers `time` and writes it into the `u_time` uniform of
 *   every shader compiled so far (main and depth, independently). Shaders
 *   compiled later are seeded with the last value.
 * - `setColor(color)`: remembers the colour and writes it into `u_color` if the
 *   main shader is already compiled. Numbers and strings are wrapped with
 *   `new Color(...)`. Any other value (e.g. a `Color`) is used as-is, by reference.
 * - `onAfterBeforeCompile(shader)`: optional hook called at the end of the main
 *   `onBeforeCompile`. It can, for example, replace the
 *   `//<inject_frag_uniforms>` / `//<inject_fragment>` markers in the fragment
 *   shader. Assign it before the first compile, or flag the material for
 *   recompilation afterwards.
 *
 * @param {Color|number|string} [color] Base colour; defaults to the theme's nodes colour.
 * @returns {MeshLambertMaterial} The configured material with the extra API above.
 */
export default function NodeSelectMaterial(color = new Color(defaultTheme[Ids.NODES_COLOR])) {
    // A bare number / string cannot be uploaded as a vec3 uniform, so wrap it.
    const toColor = (value) => (typeof value === 'number' || typeof value === 'string' ? new Color(value) : value)

    const mat = new MeshLambertMaterial()
    // Shader objects received in onBeforeCompile; null until each one compiles.
    let shader = null
    let shaderDepth = null
    // Latest values, used to seed the uniforms whenever a shader is (re)compiled.
    let matColor = toColor(color)
    let matTime = 0.0
    mat.side = DoubleSide

    mat.setTime = (time) => {
        matTime = time
        // The depth pass may compile before (or without) the main pass, so the
        // two shaders are updated independently. The `u_time` checks guard
        // against a hook that removed the uniform.
        if (shader && shader.uniforms.u_time) {
            shader.uniforms.u_time.value = time
        }
        if (shaderDepth && shaderDepth.uniforms.u_time) {
            shaderDepth.uniforms.u_time.value = time
        }
    }

    mat.setColor = (_color) => {
        matColor = toColor(_color)
        if (shader) {
            shader.uniforms.u_color.value = matColor
        }
    }

    // NOTE(review): this flag is not read anywhere in this factory; presumably
    // consumed by code elsewhere — TODO confirm or remove.
    mat.doOnAfterBeforeCompile = true
    mat.onAfterBeforeCompile = null

    // three.js caches compiled programs by `customProgramCacheKey()`, which by
    // default is `onBeforeCompile.toString()`. That text is identical for every
    // instance created by this factory, so instances with different
    // `onAfterBeforeCompile` hooks would otherwise share one program. Include the
    // hook's source in the key. Hooks with identical source but different
    // captured state can still collide.
    mat.customProgramCacheKey = () => {
        const hook = mat.onAfterBeforeCompile
        return mat.onBeforeCompile.toString() + (hook ? hook.toString() : '')
    }

    mat.depthMaterial = new MeshDepthMaterial({
        depthPacking: RGBADepthPacking,
    })

    // Replaces the depth shaders with a minimal pair that only projects the
    // position and packs the resulting depth into RGBA.
    mat.depthMaterial.onBeforeCompile = (_shader) => {
        shaderDepth = _shader
        shaderDepth.vertexShader = `
          varying vec2 vHighPrecisionZW;

          void main() {
            gl_Position = projectionMatrix * modelViewMatrix * vec4(position.xyz, 1.0);
            vHighPrecisionZW = gl_Position.zw;
          }
        `

        shaderDepth.fragmentShader = `
          #include <packing>
          varying vec2 vHighPrecisionZW;

          void main() {
            float fragCoordZ = (0.5 * vHighPrecisionZW[0] / vHighPrecisionZW[1] + 0.5);
            gl_FragColor = packDepthToRGBA(fragCoordZ);
          }
        `

        // Not referenced by the GLSL above; kept so setTime has a target.
        shaderDepth.uniforms.u_time = {
            value: matTime,
        }
    }

    mat.onBeforeCompile = (_shader) => {
        shader = _shader

        // Per-vertex colour: injects the `u_color` uniform and a `pcolor`
        // varying, then computes the lit colour right after <begin_vertex>.
        const vertexShaderReplacements = [
            {
                from: '#include <common>',
                to: `
                  #include <common>
                  uniform vec3 u_color;

                  varying vec3 pcolor;
                `,
            },
            {
                from: '#include <begin_vertex>',
                to: `
                  #include <begin_vertex>

                  pcolor = u_color;

                  vec4 pos4 = vec4(transformed, 1.0);
                  vec3 vpos = vec3(pos4.xyz) / pos4.w;
                  vec3 normalInterp = normal + vec3(0, -0.5, 0);

                  float shininessVal = 10.0; // Shininess

                  vec3 ambientColor = pcolor;
                  vec3 diffuseColor = pcolor;
                  vec3 specularColor = pcolor;

                  vec3 lightPos = vec3(1.25, 1.25, -0.5);

                  vec3 N = normalize(normalInterp);
                  vec3 L = normalize(lightPos);// - vpos);

                  // Lambert's cosine law
                  float lambertian = max(dot(N, L), 0.0);
                  float specular = 0.0;
                  if(lambertian > 0.0) {
                    vec3 R = reflect(-L, N); // Reflected light vector
                    vec3 V = normalize(-vpos); // Vector to viewer
                    // Compute the specular term
                    float specAngle = max(dot(R, V), 0.0);
                    specular = pow(specAngle, shininessVal);
                  }

                  pcolor = vec3(ambientColor +
                                lambertian * diffuseColor +
                                specular * specularColor);
                `,
            },
        ]
        vertexShaderReplacements.forEach((rep) => {
            shader.vertexShader = shader.vertexShader.replace(rep.from, rep.to)
        })

        // The Lambert fragment shader is discarded entirely. The comment
        // markers are injection points for onAfterBeforeCompile.
        shader.fragmentShader = `
          //<inject_frag_uniforms>
          varying vec3 pcolor;

          void main() {
            gl_FragColor = vec4(pcolor, 1.0);

            //<inject_fragment>
          }
        `

        shader.uniforms.u_color = {
            value: matColor,
        }
        // Always present so setTime never dereferences a missing uniform. It is
        // registered before the hook runs, so a hook may replace it.
        shader.uniforms.u_time = {
            value: matTime,
        }

        if (mat.onAfterBeforeCompile) {
            mat.onAfterBeforeCompile(_shader)
        }
    }

    return mat
}
