import { Loader, Texture, Program, Shader, Geometry, State, Buffer, Mesh, Container, TYPES } from 'pixi.js'
import { SpriteParticleRenderable, IParticleRendererCamera } from './sprite-particle-renderer'

/**
 * A particle/sprite that can be submitted to InstancedSpriteBatcher.
 * Adds the two pieces of render state the batcher groups draw calls by.
 */
export interface InstancedSpriteRenderable extends SpriteParticleRenderable {
	// Texture sampled for this sprite; the batcher binds it to a texture unit and
	// writes that unit's index into the sprite's per-instance data.
	texture: PIXI.BaseTexture
	// Blend mode used for the draw call; the batcher flushes whenever an incoming
	// sprite's blend mode differs from the one currently being batched.
	blendMode: PIXI.BLEND_MODES
}

// Paths handed to the PIXI Loader; also used as the keys to read the loaded sources back.
const PFX_VERT_SHADER_PATH = 'shaders/pfx.vert'
const PFX_FRAG_SHADER_PATH = 'shaders/pfx.frag'
// Floats per vertex in the shared (non-instanced) quad buffer: position xy + uv.
const STATIC_FLOATS_PER_VERT = 4
// Floats per instance in the per-batch buffer:
// transRot (4: x, y, rotation, texture index) + size (2) + color (4) + uvTransform (4).
const INSTANCED_FLOATS_PER_VERT = 14
// Number of corner vertices in the shared quad.
const VERTS_PER_INSTANCED_SPRITE = 4
// Quad corner positions as interleaved xy pairs, centered on the origin with side length 1.
const VERT_POSITIONS: number[] = [-0.5, -0.5, 0.5, -0.5, 0.5, 0.5, -0.5, 0.5]
// UV coordinates for the same four corners, as interleaved uv pairs.
const VERT_UVS: number[] = [0, 0, 1, 0, 1, 1, 0, 1]
// Two triangles (0-1-2 and 0-2-3) covering the quad.
const INSTANCED_SPRITE_INDICES: number[] = [0, 1, 2, 0, 2, 3]
// Number of Batch objects each batcher creates and cycles through, one per flush.
const MAX_BATCHES = 32
// Capacity, in sprites, of each batch's per-instance buffer.
const MAX_INSTANCES = 4096

/**
 * GPU resources for one draw call's worth of instanced sprites.
 * The batcher allocates several of these and advances to the next one after
 * every flush so consecutive flushes never write into the same buffer.
 */
class Batch {
	// Shader instance used when this batch's mesh is rendered.
	shader: Shader
	// Geometry holding the shared quad attributes plus this batch's per-instance attributes.
	geom: Geometry
	// Render state (blending enabled) passed to the mesh.
	state: State
	// Mesh that is rendered on flush.
	mesh: Mesh
	// GPU buffer backing the per-instance attributes; refreshed from instancedDataFloats on flush.
	instancedVBO: Buffer
	// CPU-side staging array for per-instance data; uploaded into instancedVBO on flush.
	instancedDataFloats: Float32Array
}

// Bookkeeping entry for a texture referenced by the batch currently being built.
interface BufferedTexture {
	// The texture object handed to the renderer's texture system when the batch is flushed.
	texture: any
	// Number of add calls in this batch that used the texture; only used to order binds.
	count: number
	// Texture unit the texture is bound to on flush (its position in the list when first added).
	index: number
}

/**
 * Accumulates sprites into a per-instance vertex buffer and draws them as one
 * instanced quad mesh per flush.
 *
 * A flush is forced whenever the instance buffer would fill up, the blend mode
 * changes, or a new texture would exceed `maxTextures`. Shader sources, the
 * compiled program and the static quad buffers are shared by every instance of
 * this class; the ring of `MAX_BATCHES` batches is per instance.
 *
 * The batcher is inert (all public methods return immediately) until the
 * shader sources have finished loading and `_init` has run.
 */
export class InstancedSpriteBatcher {
	// Public fields kept for API compatibility; not referenced by the code in this class.
	x: number = 0
	y: number = 0

	// --- state shared by every batcher instance ---
	private static staticVBO: Buffer
	private static indexBuffer: Buffer
	private static shaderLoader: Loader
	private static vertSrc: string
	private static fragSrc: string
	// Scratch one-element array so addInstancedSprite can reuse addInstancedSprites without allocating.
	private static reusableSingle: InstancedSpriteRenderable[] = []
	private static program: PIXI.Program
	private static staticDataFloats: Float32Array
	// [0, 1, ..., maxTextures - 1]; supplied as the `uSamplers` uniform value.
	private static samplersArray: number[] = []

	// --- per-instance state ---
	private ready = false
	// Next free float slot in the current batch's staging array.
	private nextFloatIdx = 0
	// Number of sprites written to the current batch and not yet flushed.
	private numToRender = 0

	private batches: Batch[] = []
	private currBatchIdx = 0
	private pixiRenderer: PIXI.Renderer
	private maxTextures = 0

	// Textures referenced by the sprites pending in the current batch.
	private bufferedTextures: BufferedTexture[] = []
	private currBlendMode = PIXI.BLEND_MODES.NORMAL

	/**
	 * @param pixiRenderer renderer the batches are drawn with and whose texture system binds textures
	 * @param maxTextures maximum number of distinct textures per draw call. The shared fragment
	 *   program is generated from the value given to the first batcher that initializes.
	 */
	constructor(pixiRenderer: PIXI.Renderer, maxTextures: number) {
		this.pixiRenderer = pixiRenderer
		this.maxTextures = maxTextures

		if (!InstancedSpriteBatcher.shaderLoader) {
			// First batcher ever created: kick off the shader source download.
			InstancedSpriteBatcher.shaderLoader = new Loader()
			InstancedSpriteBatcher.shaderLoader.add([PFX_VERT_SHADER_PATH, PFX_FRAG_SHADER_PATH])
			InstancedSpriteBatcher.shaderLoader.load(this._onShadersLoaded.bind(this))
		} else if (!InstancedSpriteBatcher.vertSrc) {
			// Another batcher started the download and it has not finished yet.
			InstancedSpriteBatcher.shaderLoader.onComplete.add(() => {
				this._init()
			})
		} else {
			this._init()
		}
	}

	/** Rewinds the batch ring so the next flush uses the first batch again. */
	reset() {
		this.currBatchIdx = 0
	}

	/**
	 * Uploads the pending instance data, draws it with the current blend mode and
	 * advances to the next batch. Does nothing when not ready or nothing is pending.
	 */
	flush() {
		if (!this.ready || this.numToRender <= 0) {
			return
		}

		this._bindTextures()

		const batch = this.batches[this.currBatchIdx]
		batch.instancedVBO.update(batch.instancedDataFloats)
		batch.mesh.size = 6 // index count of one quad (two triangles)
		batch.geom.instanceCount = this.numToRender
		batch.mesh.visible = this.numToRender > 0
		batch.mesh.blendMode = this.currBlendMode
		batch.mesh.render(this.pixiRenderer)

		// Move to next vbo for next batch, otherwise we'll have a CPU/GPU stall
		this.currBatchIdx = (this.currBatchIdx + 1) % MAX_BATCHES
		this.bufferedTextures = []
		this.nextFloatIdx = 0
		this.numToRender = 0
	}

	/**
	 * Starts a new run of sprites, discarding anything added but not flushed.
	 *
	 * @param camera when given, its position, zoom and half-size are written to the shader uniforms.
	 */
	beginRender(camera?: IParticleRendererCamera) {
		if (!this.ready) {
			return
		}

		if (camera) {
			// Every batch's Shader is constructed from the same `uniforms` object in _init;
			// presumably PIXI keeps a reference to it, so writing through the current batch's
			// shader is visible to all batches - TODO confirm.
			const uniforms = this.batches[this.currBatchIdx].shader.uniforms
			uniforms.u_camera[0] = camera.x
			uniforms.u_camera[1] = camera.y
			uniforms.u_camera[2] = camera.zoom
			uniforms.u_camera[3] = 1 // >0 means "use the camera" to the shader
			uniforms.u_screenHalfSize[0] = camera.halfWidth
			uniforms.u_screenHalfSize[1] = camera.halfHeight
		}
		// NOTE(review): when `camera` is omitted the previous camera uniforms, including the
		// "use the camera" flag, are left untouched - confirm this is intended.

		// Drop pending sprites together with the textures they referenced, so that
		// "nothing pending" always implies "no buffered textures".
		this.bufferedTextures = []
		this.nextFloatIdx = 0
		this.numToRender = 0
	}

	/** Adds a single sprite; see addInstancedSprites. */
	addInstancedSprite(sprite: InstancedSpriteRenderable) {
		InstancedSpriteBatcher.reusableSingle[0] = sprite
		this.addInstancedSprites(InstancedSpriteBatcher.reusableSingle, 1)
	}

	/**
	 * Adds the first `numInstancedSprites` entries of `sprites` to the current batch.
	 *
	 * addInstancedSprites assumes they all come from the same texture and blend mode
	 * (both are read from `sprites[0]`). If they don't, use addInstancedSprite n times.
	 * Requests larger than the instance buffer are split across several batches.
	 *
	 * @param sprites sprites to add, in draw order
	 * @param numInstancedSprites how many entries of `sprites` to use; values <= 0 are ignored
	 */
	addInstancedSprites(sprites: InstancedSpriteRenderable[], numInstancedSprites: number) {
		if (!this.ready || numInstancedSprites <= 0) {
			return
		}

		const texture = sprites[0].texture
		const blendMode = sprites[0].blendMode

		// Never hand more than one buffer's worth to a single batch.
		for (let start = 0; start < numInstancedSprites; start += MAX_INSTANCES) {
			const count = Math.min(MAX_INSTANCES, numInstancedSprites - start)
			this._addChunk(sprites, start, count, texture, blendMode)
		}
	}

	// Adds sprites[start .. start + count) to the current batch, flushing first when required.
	// `count` must not exceed MAX_INSTANCES.
	private _addChunk(
		sprites: InstancedSpriteRenderable[],
		start: number,
		count: number,
		texture: PIXI.BaseTexture,
		blendMode: PIXI.BLEND_MODES
	) {
		let flushed = false

		// Check if these instances will push us over the edge of the VBO - if so, flush
		// Also check if the incoming sprite(s) use a different blend mode. If so, flush.
		if (this.numToRender + count >= MAX_INSTANCES || this.currBlendMode !== blendMode) {
			this.flush()
			flushed = true
		}

		this.currBlendMode = blendMode

		if (!this._textureSeenThisBatch(texture)) {
			// Check if adding this texture would put us over the edge.
			// If so, flush (unless we just flushed)
			if (!flushed && this.bufferedTextures.length === this.maxTextures) {
				this.flush()
			}
			this.bufferedTextures.push({ texture, count: 1, index: this.bufferedTextures.length })
		}

		this._updateInstancedSpriteBuffer(sprites, count, start)
		this.numToRender += count
	}

	// Returns true (and bumps the texture's use count) when `tex` is already part of the pending batch.
	private _textureSeenThisBatch(tex: PIXI.BaseTexture) {
		for (const entry of this.bufferedTextures) {
			if (entry.texture === tex) {
				entry.count++
				return true
			}
		}
		return false
	}

	// Loader completion callback for the batcher that started the download.
	private _onShadersLoaded() {
		if (!InstancedSpriteBatcher.vertSrc) {
			const resources = InstancedSpriteBatcher.shaderLoader.resources
			InstancedSpriteBatcher.vertSrc = resources[PFX_VERT_SHADER_PATH].data
			InstancedSpriteBatcher.fragSrc = resources[PFX_FRAG_SHADER_PATH].data
		}

		this._init()
	}

	// Returns the texture unit `tex` will be bound to when the pending batch is flushed.
	private _getTextureIdx(tex: PIXI.BaseTexture) {
		for (const entry of this.bufferedTextures) {
			if (entry.texture === tex) {
				return entry.index
			}
		}
		return 0 // shouldn't happen...
	}

	// Writes sprites[startIdx .. startIdx + numInstancedSprites) into the current batch's
	// staging array, INSTANCED_FLOATS_PER_VERT floats per sprite, and advances nextFloatIdx.
	private _updateInstancedSpriteBuffer(sprites: InstancedSpriteRenderable[], numInstancedSprites: number, startIdx = 0) {
		let floatIdx = this.nextFloatIdx
		const batch = this.batches[this.currBatchIdx]
		const instancedDataFloats = batch.instancedDataFloats
		// Always accepts instanced sprites in batches of the same texture, so I can assume that here and
		// avoid per-sprite lookups
		const textureIdx = this._getTextureIdx(sprites[0].texture)
		const endIdx = startIdx + numInstancedSprites

		for (let i = startIdx; i < endIdx; ++i) {
			const p = sprites[i]
			const pos = p.pos
			const scale = p.scale
			const color = p.color
			const uvExtents = p.uvExtents

			// ins_transRot
			instancedDataFloats[floatIdx++] = pos[0]
			instancedDataFloats[floatIdx++] = pos[1]
			instancedDataFloats[floatIdx++] = p.rot
			instancedDataFloats[floatIdx++] = textureIdx

			// ins_size
			instancedDataFloats[floatIdx++] = scale[0]
			instancedDataFloats[floatIdx++] = scale[1]

			// ins_color (with pre-multiplied alpha)
			instancedDataFloats[floatIdx++] = color[0]
			instancedDataFloats[floatIdx++] = color[1]
			instancedDataFloats[floatIdx++] = color[2]
			instancedDataFloats[floatIdx++] = color[3]

			// ins_uvTransform
			instancedDataFloats[floatIdx++] = uvExtents[0]
			instancedDataFloats[floatIdx++] = uvExtents[1]
			instancedDataFloats[floatIdx++] = uvExtents[2]
			instancedDataFloats[floatIdx++] = uvExtents[3]
		}

		this.nextFloatIdx = floatIdx
	}

	// Adapted from PIXI BatchShaderGenerator
	// Builds the GLSL if/else chain that picks uSamplers[i] based on vTextureId.
	private _generateSampleSrc(maxTextures: number): string {
		let src = ''

		src += '\n'
		src += '\n'

		for (let i = 0; i < maxTextures; i++) {
			if (i > 0) {
				src += '\nelse '
			}

			// The last sampler is the unconditional fallback branch.
			if (i < maxTextures - 1) {
				src += `if(vTextureId < ${i}.5)`
			}

			src += '\n{'
			src += `\n\tcolor = texture2D(uSamplers[${i}], vTextureCoord);`
			src += '\n}'
		}

		src += '\n'
		src += '\n'

		return src
	}

	// Creates the shared program and quad buffers (first caller only), then this
	// batcher's ring of batches, and finally marks the batcher ready.
	private _init() {
		if (!InstancedSpriteBatcher.staticVBO) {
			const vertexSrc = InstancedSpriteBatcher.vertSrc
			const fragSrc = InstancedSpriteBatcher.fragSrc
			let modifiedFragSrc = fragSrc.replace(/%count%/gi, `${this.maxTextures}`)
			modifiedFragSrc = modifiedFragSrc.replace(/%forloop%/gi, this._generateSampleSrc(this.maxTextures))

			for (let i = 0; i < this.maxTextures; ++i) {
				InstancedSpriteBatcher.samplersArray.push(i)
			}

			// Program.from uses a cache, so don't need to worry about duplicate programs being created
			InstancedSpriteBatcher.program = Program.from(vertexSrc, modifiedFragSrc, 'pfxShader')

			const staticData = new Float32Array(STATIC_FLOATS_PER_VERT * VERTS_PER_INSTANCED_SPRITE)
			const indexData = new Uint16Array(INSTANCED_SPRITE_INDICES)

			// Interleave position xy and uv for each quad corner.
			let ofs = 0
			for (let i = 0; i < VERTS_PER_INSTANCED_SPRITE; ++i) {
				staticData[ofs++] = VERT_POSITIONS[i * 2 + 0]
				staticData[ofs++] = VERT_POSITIONS[i * 2 + 1]
				staticData[ofs++] = VERT_UVS[i * 2 + 0]
				staticData[ofs++] = VERT_UVS[i * 2 + 1]
			}

			InstancedSpriteBatcher.staticDataFloats = staticData
			InstancedSpriteBatcher.staticVBO = new Buffer(staticData, true, false)
			InstancedSpriteBatcher.indexBuffer = new Buffer(indexData, true, true)
		}

		const uniforms = {
			// u_tex: null,
			u_camera: [0, 0, 0, 0],
			u_screenHalfSize: [0, 0],
			uSamplers: InstancedSpriteBatcher.samplersArray,
		}

		// Strides in bytes (4 bytes per float).
		const staticVertSize = STATIC_FLOATS_PER_VERT * 4
		const instVertSize = INSTANCED_FLOATS_PER_VERT * 4

		for (let batchIdx = 0; batchIdx < MAX_BATCHES; ++batchIdx) {
			const batch = new Batch()
			batch.shader = new Shader(InstancedSpriteBatcher.program, uniforms)
			batch.geom = new Geometry()
			batch.state = new State()

			batch.instancedDataFloats = new Float32Array(MAX_INSTANCES * INSTANCED_FLOATS_PER_VERT)
			batch.instancedVBO = new Buffer(batch.instancedDataFloats, false, false)

			// Shared quad: byte offsets 0 (position) and 8 (uv) within each vertex.
			batch.geom.addAttribute('aPosition', InstancedSpriteBatcher.staticVBO, 2, false, TYPES.FLOAT, staticVertSize, 0)
			batch.geom.addAttribute('aUV', InstancedSpriteBatcher.staticVBO, 2, false, TYPES.FLOAT, staticVertSize, 8)

			// Per-instance data: byte offsets follow the write order in _updateInstancedSpriteBuffer.
			batch.geom.addAttribute('ins_transRot', batch.instancedVBO, 4, false, TYPES.FLOAT, instVertSize, 0, true)
			batch.geom.addAttribute('ins_size', batch.instancedVBO, 2, false, TYPES.FLOAT, instVertSize, 16, true)
			batch.geom.addAttribute('ins_color', batch.instancedVBO, 4, false, TYPES.FLOAT, instVertSize, 24, true)
			batch.geom.addAttribute('ins_uvTransform', batch.instancedVBO, 4, false, TYPES.FLOAT, instVertSize, 40, true)

			batch.geom.addIndex(InstancedSpriteBatcher.indexBuffer)
			batch.state.blend = true

			batch.mesh = new Mesh(batch.geom, batch.shader, batch.state)
			batch.mesh.blendMode = this.currBlendMode

			this.batches.push(batch)
		}

		this.ready = true
	}

	// Binds every texture of the pending batch to the texture unit recorded in its entry.
	private _bindTextures() {
		const textureSystem = this.pixiRenderer.texture
		const textures = this.bufferedTextures

		// Most-used textures are bound first; each still goes to its own recorded unit.
		textures.sort((a, b) => b.count - a.count)

		for (let i = 0, len = textures.length; i < len; i++) {
			textureSystem.bind(textures[i].texture, textures[i].index)
		}
	}
}
