// RAVU Zoom Upscaling Shader (R2 Variant, Luma)
// Rapid and Accurate Video Upscaling - Arbitrary scale version
// Edge-adaptive upscaler using radius 2 kernel for arbitrary scale factors

//!BGFX EFFECT
//!VERSION 1
//!NAME RAVU Zoom R2
//!CATEGORY Upscaling
//!DESCRIPTION Arbitrary-scale edge-adaptive upscaler with radius 2 kernel. Processes luminance channel only. Supports any output resolution.

// Pass 1 input image (listed in //!IN). Read through INPUT_tex for the luma
// samples and directly in imageStoreOverride for the chroma components.
//!TEXTURE
Texture2D INPUT;

// Point-filtered sampler. Not referenced by name in this file; presumably
// picked up by the texture() helper from prescalers.hlsli -- TODO confirm.
//!SAMPLER
//!FILTER POINT
SamplerState sam_INPUT;

// Pass 1 output image (listed in //!OUT). Written per pixel by imageStoreOverride.
//!TEXTURE
Texture2D OUTPUT;

// Linear-filtered sampler used by imageStoreOverride when it reads INPUT at
// the output pixel's mapped position.
//!SAMPLER
//!FILTER LINEAR
SamplerState sam_INPUT_LINEAR;

// Weight lookup table loaded from ravu_zoom_lut2_f16.dds (four 16-bit float
// channels per texel). Pass 1 fetches four texels from it per output pixel and
// uses each channel as the weight of one input sample.
//!TEXTURE
//!SOURCE ravu_zoom_lut2_f16.dds
//!FORMAT R16G16B16A16_FLOAT
Texture2D ravu_zoom_lut2;

// Linear-filtered sampler for the LUT. Not referenced by name in this file;
// presumably picked up by the texture() helper -- TODO confirm.
//!SAMPLER
//!FILTER LINEAR
SamplerState sam_ravu_zoom_lut2;

//!COMMON
// prescalers.hlsli is not part of this file. The code below relies on it for
// every GLSL-style name it uses (vec2/ivec2, mix, fract, mod, atan, texture,
// barrier, shared, gl_*), for HOOKED_map, and for the colour constants
// rgb2y / rgb2uv / yuv2rgb -- NOTE(review): confirm against the header.
#include "prescalers.hlsli"

// Index of the final pass. Pass code compares CURRENT_PASS against this value
// to decide whether to compile the output-bounds check.
#define LAST_PASS 1

//!PASS 1
//!DESC RAVU-Zoom Upscale (luma, r2, compute)
//!IN INPUT, ravu_zoom_lut2
//!OUT OUTPUT
//!BLOCK_SIZE 32, 8
//!NUM_THREADS 32, 8
// Remaps x from [0, 1] to [0.5 / lut_size, 1 - 0.5 / lut_size], i.e. from the
// centre of the first LUT cell to the centre of the last one (assumes mix() is
// GLSL-style linear interpolation supplied by prescalers.hlsli).
#define LUTPOS(x, lut_size) mix(0.5 / (lut_size), 1.0 - 0.5 / (lut_size), (x))
// Luma cache filled and read by Pass1, addressed as x + y * 36, so 432 floats
// hold a 36 x 12 tile. `shared` is presumably mapped to HLSL `groupshared` by
// prescalers.hlsli -- TODO confirm.
shared float samples[432];

// Index of this pass; compared with LAST_PASS inside Pass1.
#define CURRENT_PASS 1

// Reduces a colour to a single value: dot product of its rgb with rgb2y.
#define GET_SAMPLE(x) dot(x.rgb, rgb2y)
// GLSL-style imageStore shim: the image argument is ignored and only the
// first component of `val` is forwarded to imageStoreOverride.
#define imageStore(out_image, pos, val) imageStoreOverride(pos, val.x)

// Writes one output pixel. `value` supplies the first (luma) component; the
// two chroma components are taken from INPUT, sampled with the linear sampler
// at the position HOOKED_map returns for `pos`. The combined triple is
// converted with yuv2rgb and stored with alpha 1.
void imageStoreOverride(uint2 pos, float value) {
	float3 sourceRgb = INPUT.SampleLevel(sam_INPUT_LINEAR, HOOKED_map(pos), 0).rgb;
	float2 chroma = mul(rgb2uv, sourceRgb);
	float3 yuv = float3(value, chroma);
	OUTPUT[pos] = float4(mul(yuv2rgb, yuv), 1.0);
}

// Samples INPUT at `pos` through the texture() helper and collapses the result
// to a single value with GET_SAMPLE.
#define INPUT_tex(pos) GET_SAMPLE(vec4(texture(INPUT, pos)))
// Values returned by GetInputSize() / GetInputPt(), converted to float2.
// Pass1 multiplies integer texel coordinates (+0.5) by INPUT_pt to build
// sampling positions, so it presumably holds 1 / INPUT_size -- TODO confirm.
static const float2 INPUT_size = float2(GetInputSize());
static const float2 INPUT_pt = float2(GetInputPt());

// Four-component fetch from the weight LUT (not used by the code in this file,
// which calls texture() on the LUT directly).
#define ravu_zoom_lut2_tex(pos) (vec4(texture(ravu_zoom_lut2, pos)))

// The HOOKED_* names used by Pass1 are aliases for the INPUT_* definitions above.
#define HOOKED_tex(pos) INPUT_tex(pos)
#define HOOKED_size INPUT_size
#define HOOKED_pt INPUT_pt

// Pass 1: produces one output pixel per thread.
//
// Steps:
//   1. All threads of the group cooperatively copy the input luma tile needed
//      by the group into `samples`.
//   2. Each thread picks the 4x4 neighbourhood around its source position,
//      builds a weighted structure tensor from finite differences and derives
//      an (angle, strength, coherence) class from its eigen-decomposition.
//   3. The class and the sub-pixel offset select 16 weights from
//      ravu_zoom_lut2; the weighted sum of the 16 samples is the new luma.
//   4. imageStore (see imageStoreOverride) recombines that luma with chroma
//      taken from INPUT and writes the pixel to OUTPUT.
//
// Parameters:
//   blockStart - added to threadId.xy to obtain the destination pixel used by
//                the bounds check.
//   threadId   - thread coordinates inside the block (only .xy is read here).
// The gl_* identifiers come from prescalers.hlsli and are presumably derived
// from these two parameters -- TODO confirm.
//
// NOTE(review): `samples` holds a 36 x 12 tile. That looks sufficient for a
// 32 x 8 block only while the tile covered by the block does not exceed that
// size (i.e. presumably scale factors >= 1); nothing here guards the indices
// for larger tiles -- confirm how downscaling is handled by the caller.
void Pass1(uint2 blockStart, uint3 threadId) {
	// Calculate sample region for this workgroup: first/last output pixel of
	// the group mapped into input texel space, widened by the kernel footprint
	// (one texel before, two after).
	ivec2 group_begin = ivec2(gl_WorkGroupID) * ivec2(gl_WorkGroupSize);
	ivec2 group_end = group_begin + ivec2(gl_WorkGroupSize) - ivec2(1, 1);
	ivec2 rectl = ivec2(floor(HOOKED_size * HOOKED_map(group_begin) - 0.5001)) - 1;
	ivec2 rectr = ivec2(floor(HOOKED_size * HOOKED_map(group_end) - 0.4999)) + 2;
	ivec2 rect = rectr - rectl + 1;

	// Load samples to shared memory. Threads stride over the rect.x * rect.y
	// texels; each texel is sampled at its centre (+0.5) and stored with a
	// fixed row stride of 36.
	for (int id = int(gl_LocalInvocationIndex); id < rect.x * rect.y;
		 id += int(gl_WorkGroupSize.x * gl_WorkGroupSize.y)) {
		uint y = (uint)id / rect.x, x = (uint)id % rect.x;
		samples[x + y * 36] = HOOKED_tex(HOOKED_pt * (vec2(rectl + ivec2(x, y)) + vec2(0.5, 0.5))).x;
	}
	barrier();
#if CURRENT_PASS == LAST_PASS
	// The bounds check is placed after the load loop and the barrier so that
	// threads outside the output still take part in filling `samples`.
	uint2 destPos = blockStart + threadId.xy;
	uint2 outputSize = GetOutputSize();
	if (destPos.x >= outputSize.x || destPos.y >= outputSize.y) {
		return;
	}
#endif

	// Calculate subpixel position for LUT lookup. `pos` ends up on the centre
	// of the texel at or before the source position; `subpix` is the
	// fractional offset from it, remapped by LUTPOS for a 9-entry axis.
	vec2 pos = HOOKED_size * HOOKED_map(ivec2(gl_GlobalInvocationID));
	vec2 subpix = fract(pos - 0.5);
	pos -= subpix;
	subpix = LUTPOS(subpix, vec2(9.0, 9.0));
	vec2 subpix_inv = 1.0 - subpix;
	// Scale into one LUT cell: x spans half the texture (the two halves are
	// addressed below with x offsets 0.0 and 0.5), y spans 1/288 of it (one of
	// the 288 classes selected by coord_y).
	subpix /= vec2(2.0, 288.0);
	subpix_inv /= vec2(2.0, 288.0);

	// Get sample positions. With the stride of 36, +-1 moves along x and +-36
	// along y, so sampleN covers column offset (N / 4 - 1) and row offset
	// (N % 4 - 1) relative to `ipos`:
	//   sample0..3   x-1, y-1..y+2
	//   sample4..7   x,   y-1..y+2
	//   sample8..11  x+1, y-1..y+2
	//   sample12..15 x+2, y-1..y+2
	ivec2 ipos = ivec2(floor(pos)) - rectl;
	int lpos = ipos.x + ipos.y * 36;
	float sample0 = samples[-37 + lpos];
	float sample1 = samples[-1 + lpos];
	float sample2 = samples[35 + lpos];
	float sample3 = samples[71 + lpos];
	float sample4 = samples[-36 + lpos];
	float sample5 = samples[0 + lpos];
	float sample6 = samples[36 + lpos];
	float sample7 = samples[72 + lpos];
	float sample8 = samples[-35 + lpos];
	float sample9 = samples[1 + lpos];
	float sample10 = samples[37 + lpos];
	float sample11 = samples[73 + lpos];
	float sample12 = samples[-34 + lpos];
	float sample13 = samples[2 + lpos];
	float sample14 = samples[38 + lpos];
	float sample15 = samples[74 + lpos];

	// Structure tensor computation for edge detection. For each of the 16
	// positions, gx / gy are finite differences along x / y (central
	// differences halved, one-sided differences at the border of the 4x4
	// window) and (gx*gx, gx*gy, gy*gy) is accumulated with a per-position
	// weight that is largest for the four central positions.
	vec3 abd = vec3(0.0, 0.0, 0.0);
	float gx, gy;
	gx = (sample4 - sample0);
	gy = (sample1 - sample0);
	abd += vec3(gx * gx, gx * gy, gy * gy) * 0.04792235409415088;
	gx = (sample5 - sample1);
	gy = (sample2 - sample0) / 2.0;
	abd += vec3(gx * gx, gx * gy, gy * gy) * 0.06153352068439959;
	gx = (sample6 - sample2);
	gy = (sample3 - sample1) / 2.0;
	abd += vec3(gx * gx, gx * gy, gy * gy) * 0.06153352068439959;
	gx = (sample7 - sample3);
	gy = (sample3 - sample2);
	abd += vec3(gx * gx, gx * gy, gy * gy) * 0.04792235409415088;
	gx = (sample8 - sample0) / 2.0;
	gy = (sample5 - sample4);
	abd += vec3(gx * gx, gx * gy, gy * gy) * 0.06153352068439959;
	gx = (sample9 - sample1) / 2.0;
	gy = (sample6 - sample4) / 2.0;
	abd += vec3(gx * gx, gx * gy, gy * gy) * 0.07901060453704994;
	gx = (sample10 - sample2) / 2.0;
	gy = (sample7 - sample5) / 2.0;
	abd += vec3(gx * gx, gx * gy, gy * gy) * 0.07901060453704994;
	gx = (sample11 - sample3) / 2.0;
	gy = (sample7 - sample6);
	abd += vec3(gx * gx, gx * gy, gy * gy) * 0.06153352068439959;
	gx = (sample12 - sample4) / 2.0;
	gy = (sample9 - sample8);
	abd += vec3(gx * gx, gx * gy, gy * gy) * 0.06153352068439959;
	gx = (sample13 - sample5) / 2.0;
	gy = (sample10 - sample8) / 2.0;
	abd += vec3(gx * gx, gx * gy, gy * gy) * 0.07901060453704994;
	gx = (sample14 - sample6) / 2.0;
	gy = (sample11 - sample9) / 2.0;
	abd += vec3(gx * gx, gx * gy, gy * gy) * 0.07901060453704994;
	gx = (sample15 - sample7) / 2.0;
	gy = (sample11 - sample10);
	abd += vec3(gx * gx, gx * gy, gy * gy) * 0.06153352068439959;
	gx = (sample12 - sample8);
	gy = (sample13 - sample12);
	abd += vec3(gx * gx, gx * gy, gy * gy) * 0.04792235409415088;
	gx = (sample13 - sample9);
	gy = (sample14 - sample12) / 2.0;
	abd += vec3(gx * gx, gx * gy, gy * gy) * 0.06153352068439959;
	gx = (sample14 - sample10);
	gy = (sample15 - sample13) / 2.0;
	abd += vec3(gx * gx, gx * gy, gy * gy) * 0.06153352068439959;
	gx = (sample15 - sample11);
	gy = (sample15 - sample14);
	abd += vec3(gx * gx, gx * gy, gy * gy) * 0.04792235409415088;

	// Eigenvalue decomposition of the symmetric matrix [[a, b], [b, d]]:
	// T is its trace, D its determinant, L1 >= L2 its eigenvalues.
	float a = abd.x, b = abd.y, d = abd.z;
	float T = a + d, D = a * d - b * b;
	float delta = sqrt(max(T * T / 4.0 - D, 0.0));
	float L1 = T / 2.0 + delta, L2 = T / 2.0 - delta;
	// L2 is mathematically >= 0, but when all gradients are (nearly) parallel
	// it is ~0 and rounding can leave it marginally negative. sqrt() of a
	// negative value is NaN, which would propagate into `mu`, make both
	// `mu >=` comparisons below fail and silently classify a perfectly
	// coherent edge as coherence 0. Clamp before taking the root.
	float sqrtL1 = sqrt(L1), sqrtL2 = sqrt(max(L2, 0.0));
	// Orientation folded into [0, pi); forced to 0 when the off-diagonal term
	// is negligible.
	float theta = mix(mod(atan(L1 - a, b) + 3.141592653589793, 3.141592653589793), 0.0, abs(b) < 1.192092896e-7);
	// lambda: gradient strength. mu: coherence in [0, 1], forced to 0 when
	// both roots are negligible to avoid dividing by ~0.
	float lambda = sqrtL1;
	float mu = mix((sqrtL1 - sqrtL2) / (sqrtL1 + sqrtL2), 0.0, sqrtL1 + sqrtL2 < 1.192092896e-7);

	// LUT coordinate calculation: quantise into 24 angle bins, 4 strength
	// levels (thresholds 0.004 / 0.016 / 0.05) and 3 coherence levels
	// (thresholds 0.25 / 0.5); the combined class index 0..287 selects the
	// LUT row band.
	float angle = floor(theta * 24.0 / 3.141592653589793);
	float strength = mix(mix(0.0, 1.0, lambda >= 0.004), mix(2.0, 3.0, lambda >= 0.05), lambda >= 0.016);
	float coherence = mix(mix(0.0, 1.0, mu >= 0.25), 2.0, mu >= 0.5);
	float coord_y = ((angle * 4.0 + strength) * 3.0 + coherence) / 288.0;

	// Weighted sample accumulation from LUT. Each fetch yields four weights.
	// The first two fetches (x offsets 0.0 and 0.5) weight samples 0..7; the
	// last two reuse the same LUT halves at the mirrored sub-pixel position
	// and apply the weights to samples 15..8 in reverse order.
	float res = 0.0;
	vec4 w;
	w = texture(ravu_zoom_lut2, vec2(0.0, coord_y) + subpix);
	res += sample0 * w[0];
	res += sample1 * w[1];
	res += sample2 * w[2];
	res += sample3 * w[3];
	w = texture(ravu_zoom_lut2, vec2(0.5, coord_y) + subpix);
	res += sample4 * w[0];
	res += sample5 * w[1];
	res += sample6 * w[2];
	res += sample7 * w[3];
	w = texture(ravu_zoom_lut2, vec2(0.0, coord_y) + subpix_inv);
	res += sample15 * w[0];
	res += sample14 * w[1];
	res += sample13 * w[2];
	res += sample12 * w[3];
	w = texture(ravu_zoom_lut2, vec2(0.5, coord_y) + subpix_inv);
	res += sample11 * w[0];
	res += sample10 * w[1];
	res += sample9 * w[2];
	res += sample8 * w[3];
	res = clamp(res, 0.0, 1.0);

	// Output single pixel. imageStore expands to imageStoreOverride, which
	// merges this luma with chroma read from INPUT before writing OUTPUT.
	imageStore(out_image, ivec2(gl_GlobalInvocationID), res);
}
