Stepping through the process of creating a realtime median filter shader for Stage3d that pushes affectionately against the platform’s limits.

# Goal

Before setting out to make any shader, it’s a good idea to know what you’re aiming at. In this case we’re in luck, as a median filter is fairly well defined. It’s a neighbourhood filter: each pixel has its new value determined by the pixels in a small window around it, much like a convolution filter. Unlike a true convolution, though, it is not a weighted sum. Instead it calculates the median value (place all the values in order, and take the middle one) of the pixels in that window. In our case we’ll be making a median filter with a 3×3 area, as we expect our noise to be mostly single pixels that have been altered. Also in our case we’ll be filtering on a per-channel basis. That means finding the median red value, median green value, and median blue value, rather than finding the colour value with the median brightness or whatever other measurement you might like to use.

Median filters are useful in image processing for removing small-scale noise while maintaining the edges of the original image. They’re used very often in digital cameras to remove the dusting of noise that appears when shooting in low light, when using a low-cost sensor, or simply when hit by interference.

# A First Draft

Let’s get straight into designing our shader. I’ll mostly be writing code in Haxe’s HxSL, but for the very start let’s have some pseudo-code for a simple median filter.

```
declare an array of float3, length 9, named v
for x in -1 to +1 {
    for y in -1 to +1 {
        // offset by 1 so the indices run from 0 to 8
        v[(x + 1) + (y + 1) * 3] = sample texture at (current location + (x, y))
    }
}
sort each channel of v by value
return v[4]
```
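To have something to check a shader against, here is a minimal CPU sketch of that pseudo-code in Python. It assumes the image is a list of rows of `(r, g, b)` tuples. Each neighbour is stored at index `(x + 1) + (y + 1) * 3` so the indices run from 0 to 8. Out-of-range neighbours are clamped to the nearest edge pixel as a rough stand-in for the shader's `clamp` texture addressing. The function name `median_filter_3x3` is purely illustrative.

```python
def median_filter_3x3(image):
    """Exact per-channel 3x3 median of an image given as rows of (r, g, b) tuples.

    Neighbours outside the image are clamped to the nearest edge pixel.
    """
    height, width = len(image), len(image[0])
    result = []
    for y in range(height):
        row = []
        for x in range(width):
            # gather the 3x3 neighbourhood; entry (dx, dy) lands at index (dx + 1) + (dy + 1) * 3
            v = []
            for dy in (-1, 0, 1):
                for dx in (-1, 0, 1):
                    sx = min(max(x + dx, 0), width - 1)
                    sy = min(max(y + dy, 0), height - 1)
                    v.append(image[sy][sx])
            # sort each channel independently and take the middle (index 4) value
            row.append(tuple(sorted(p[c] for p in v)[4] for c in range(3)))
        result.append(row)
    return result


grey = (100, 100, 100)
image = [[grey] * 5 for _ in range(5)]
image[2][2] = (255, 0, 30)  # a single noisy pixel
cleaned = median_filter_3x3(image)
print(cleaned[2][2])  # (100, 100, 100)
```

Every 3×3 window here contains at most one noisy pixel, so the median in each channel is always the grey value and the noise disappears completely.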

Two implementation issues appear when it comes to writing that as an actual shader.

Firstly, **how do we sort values in a shader?** Stage3d’s shaders have no support for looping or conditional branching, so there can be no straightforward “if a is larger than b then …” or “repeat until in order”. It actually ends up being quite easy to do without such constructs; it just takes a little change of thinking.

Secondly, a **lack of register space**. Stage3d’s shaders only have write access to 8 temporary registers, which is notable for being one less than the 9 required to store all the values from the 3×3 sample area.

# Sorting in a Shader

As every young computer scientist knows, bubble sort is the poor relation of the sorting family. But let’s use it in our shader anyway.

The advantage for us is that bubble sort is delightfully simple. Just advance through the array, compare each value to its neighbour and swap them if they’re in the wrong order. Do that enough times and you end up with the array in order. The question then is how do we perform the swapping of unordered neighbours?

Instead of thinking of things swapping positions, think of how you want the data to be after the operation.

```
// after sorting, v0 should have the smaller value, v1 the larger
v0 = min(v0, v1);
v1 = max(v0, v1); // problem!
// because we just wrote to v0 it might have changed value,
// so the second line doesn't do what we expect.

// let's keep a copy of v0 safe in a temporary variable
temp = v0;
v0 = min(temp, v1);
v1 = max(temp, v1);
```
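The same pitfall is easy to reproduce on the CPU. In this small Python sketch (both function names are illustrative), the first version copies the problem lines and loses a value. The second keeps a copy of `v0` in `temp`, as the comments above suggest.

```python
def broken_exchange(v0, v1):
    v0 = min(v0, v1)
    v1 = max(v0, v1)  # reads the v0 we just overwrote
    return v0, v1


def compare_exchange(v0, v1):
    temp = v0  # keep the original v0 safe
    v0 = min(temp, v1)
    v1 = max(temp, v1)
    return v0, v1


print(broken_exchange(5, 2))   # (2, 2): the 5 has been lost
print(compare_exchange(5, 2))  # (2, 5)
```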

That has placed two values in the correct order. Let’s now extend that to sort three values.

```
var v0 = 6;
var v1 = 3;
var v2 = 1;
var temp;

temp = v0;
v0 = min(temp, v1);
v1 = max(temp, v1);
temp = v1;
v1 = min(temp, v2);
v2 = max(temp, v2);
// 1st iteration of bubble sort is complete
// we now know that v2 holds the maximum value
// but v0 and v1 might still be in the wrong order
temp = v0;
v0 = min(temp, v1);
v1 = max(temp, v1);
// 2nd iteration of bubble sort is complete
// we now know that v0, v1 and v2 hold values
// in ascending order
```
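Because the three-value sort uses nothing but `min` and `max`, it can be checked exhaustively. This Python sketch (the name `sort3` is illustrative) follows the steps above one for one. It then confirms that every ordering of the inputs comes out sorted.

```python
from itertools import permutations


def sort3(v0, v1, v2):
    """Sort three values with min/max only, following the bubble sort steps."""
    # 1st bubble sort iteration: (v0, v1) then (v1, v2)
    temp = v0
    v0 = min(temp, v1)
    v1 = max(temp, v1)
    temp = v1
    v1 = min(temp, v2)
    v2 = max(temp, v2)
    # v2 now holds the maximum; the 2nd iteration only touches v0 and v1
    temp = v0
    v0 = min(temp, v1)
    v1 = max(temp, v1)
    return v0, v1, v2


print(sort3(6, 3, 1))  # (1, 3, 6)
# every ordering of the inputs comes out sorted
assert all(sort3(*p) == (1, 3, 6) for p in permutations((6, 3, 1)))
```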

It’s old news to our young computer scientist, but each iteration of bubble sort guarantees that the largest value not yet in place ends up in its correct final position, at the end of the portion being sorted. That means we don’t need to touch it on future iterations, so future iterations are shorter. It also means that for an array of **n** elements, every element is certain to be in its correct place after **n-1** iterations. As we’re doing all this to find the median value, we don’t actually have to do the full n-1 iterations. We only need to be sure that the middle value is correctly placed in order to read off the median.
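Here is a Python sketch of that shortcut, assuming an odd number of values (the name `bubble_median` is illustrative). It runs just enough passes to fix the middle slot and counts the compare-and-swap steps. For 3 values that is still both passes. For 9 values it needs 5 of the 8 passes, or 30 compare-and-swaps instead of 36.

```python
def bubble_median(values):
    """Median of an odd-length list via bubble sort, stopping early.

    Pass p (counting from 0) moves the largest value in v[0 .. n-1-p] into
    position n-1-p, so after n - mid passes every slot from mid upwards is final.
    Returns (median, number of compare-and-swap steps).
    """
    v = list(values)
    n = len(v)
    mid = n // 2
    steps = 0
    for p in range(n - mid):        # only the passes needed to fix v[mid]
        for i in range(n - 1 - p):  # each pass is one shorter than the last
            v[i], v[i + 1] = min(v[i], v[i + 1]), max(v[i], v[i + 1])
            steps += 1
    return v[mid], steps


print(bubble_median([7, 1, 9, 3, 5, 8, 2, 6, 4]))  # (5, 30)
print(bubble_median([6, 3, 1]))                    # (3, 3)
```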

# Taking into Account the Three Channels

At this point it’s important to remember that our shader is dealing with Float3 variables (that is, each variable has an x, y, and z component representing the three colour channels) rather than the Float variables we’ve used in the examples so far. AGAL’s **min** and **max** operations work on each channel separately, so we are performing our sort on a per-channel basis. As a result, our “median colour” might be a colour that isn’t actually present in the source image. It may be built up from the red of one colour, the green of another and the blue of a third. In my brief experiments this per-channel sorting actually seems to produce better results, but it’s worth showing how to adapt the sort so that it maintains colours.
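To see how a per-channel median can invent a colour, here is a tiny Python check. It uses three made-up pixels of pure red, green and blue, an extreme case chosen for illustration.

```python
pixels = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]  # pure red, green and blue

# per-channel median: sort each channel on its own and take the middle value
per_channel = tuple(sorted(p[c] for p in pixels)[1] for c in range(3))
print(per_channel)            # (0, 0, 0)
print(per_channel in pixels)  # False: black is not one of the inputs
```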

If we’re not going to sort each channel separately on its value, we’ll need some other metric by which to sort. For simplicity we’ll sort on the sum of the three channels, but you could calculate the colour’s hue, saturation, value, or whatever you like that can be expressed as a single number. Stage3d’s registers all consist of four components, and our colour values are only using three, so let’s use the fourth to store the value we’ll sort them by.

```
// put value to sort by in .w component
v0.w = v0.x + v0.y + v0.z;
v1.w = v1.x + v1.y + v1.z;
v2.w = v2.x + v2.y + v2.z;
```

Rather than using **min** and **max** to perform the sorting, we’ll use **lt** (less than) and **gte** (greater than or equal). These operations return 1 if they are true, and 0 otherwise. Note that like most operations these are performed on all components of a variable, and as we want just a single value we give them just a single component – the one by which we want our values sorted.

```
firstIsSmaller = lt(v0.w, v1.w);
secondIsSmaller = gte(v0.w, v1.w);
```

As we know that these are always either 0 or 1 and that when one is 1 the other is 0, we can construct the desired new value of v0 and v1.

```
temp = v0;
// read from temp, as v0 has already been overwritten by the time v1 is computed
v0 = firstIsSmaller * temp + secondIsSmaller * v1;
v1 = secondIsSmaller * temp + firstIsSmaller * v1;
```
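The whole colour-preserving sort can be emulated on the CPU. This is a sketch under a few assumptions: colours are plain tuples, the fourth element plays the role of `.w`, and `float(a < b)` / `float(a >= b)` stand in for the 0-or-1 results of `lt` and `gte`. The helper names are illustrative. Because each pair of masks is one 1.0 and one 0.0, the weighted sums select whole colours rather than mixing channels.

```python
def with_key(rgb):
    """Append the sort key (r + g + b) as a fourth, .w-style component."""
    r, g, b = rgb
    return (r, g, b, r + g + b)


def exchange(v0, v1):
    """Order two keyed colours by blending with 0/1 masks instead of branching."""
    first_is_smaller = float(v0[3] < v1[3])    # stands in for lt(v0.w, v1.w)
    second_is_smaller = float(v0[3] >= v1[3])  # stands in for gte(v0.w, v1.w)
    temp = v0
    v0 = tuple(first_is_smaller * a + second_is_smaller * b for a, b in zip(temp, v1))
    v1 = tuple(second_is_smaller * a + first_is_smaller * b for a, b in zip(temp, v1))
    return v0, v1


def median_colour(c0, c1, c2):
    """Median of three colours by channel sum, returned as an (r, g, b) tuple."""
    v0, v1, v2 = with_key(c0), with_key(c1), with_key(c2)
    v0, v1 = exchange(v0, v1)
    v1, v2 = exchange(v1, v2)
    v0, v1 = exchange(v0, v1)
    return v1[:3]


print(median_colour((250, 250, 250), (20, 20, 20), (200, 0, 0)))  # (200.0, 0.0, 0.0)
```

For comparison, a per-channel median of the same three colours would be (200, 20, 20), which is not one of the inputs.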

Using this technique you can do all kinds of conditional-based stuff in shaders, despite the lack of an **if** statement.

# Squeezing into a Size Eight

Being able to sort a set of variables isn’t much use if we can’t actually store them in the first place, and with just 8 temporary registers available to us things are looking pretty tight when we’re sampling 9 locations, and need a temporary variable to perform the sort.

As I wrote this section I realised that it would be possible to fit all 9 Float3s, plus the one temporary Float3 used during sorting, into the 8 available Float4 registers: that’s 30 components, against 32 available. It would require several of the variables to be split between multiple registers, which would make your code joyfully unreadable, but it would work. If you’re using Float4s for each sampled location, though, that of course could not fit in 8 registers.

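Thankfully we can get a good approximation of the median by analysing subsets of the data. We need to find the median of 9 values, but are unable to store all 9 in memory at any one time. Instead we shall store the first 3 values, then find and store their median. We do the same for the next subset of 3 values, and then the final 3. Taking the median of those three medians gives a value close to the median of the whole set.

It looks as though this should be exactly true when the subsets are all of equal size, all have an odd number of members, and there’s an odd number of subsets, but unfortunately it isn’t. For example, rows {1, 2, 3}, {4, 8, 9} and {5, 6, 7} have medians 2, 8 and 6. The median of those medians is 6, whereas the true median of the values 1 to 9 is 5.

The result is never far off, though. At least two of the three row medians are no larger than it, and each of those rows has two values no larger than its own median, so at least four of the nine values are no larger than the result. The same argument works from above. So, for distinct values, the median of medians is always the 4th, 5th or 6th smallest of the nine. In practice it does seem to give the expected results.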
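The row-by-row idea is easy to test numerically. This Python sketch (function names are illustrative) compares the median of row medians with the true median. It shows a case where they differ, then checks on random sets of nine distinct values that the approximation always lands on the 4th, 5th or 6th smallest value.

```python
import random


def median3(a, b, c):
    return sorted((a, b, c))[1]


def median_of_row_medians(v):
    """v holds 9 values in row-major order: rows v[0:3], v[3:6] and v[6:9]."""
    return median3(median3(*v[0:3]), median3(*v[3:6]), median3(*v[6:9]))


# rows {1, 2, 3}, {4, 8, 9}, {5, 6, 7}: the true median of 1..9 is 5
rows = [1, 2, 3, 4, 8, 9, 5, 6, 7]
print(median_of_row_medians(rows))  # 6
print(sorted(rows)[4])              # 5

# the approximation always lands on the 4th, 5th or 6th smallest value
rng = random.Random(0)
for _ in range(10000):
    v = rng.sample(range(100), 9)                     # nine distinct values
    rank = sorted(v).index(median_of_row_medians(v))  # 0-based rank
    assert 3 <= rank <= 5
```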

```haxe
var median0:Float3;
var median1:Float3;
var median2:Float3;
var textureUV:Float2;
var v0:Float3;
var v1:Float3;
var v2:Float3;
var temp:Float3;

// sample locations (-1,-1) (0,-1) (+1,-1) into v0 v1 v2
// bubble sort on v0 v1 v2
median0 = v1;

// sample locations (-1,0) (0,0) (+1,0) into v0 v1 v2
// bubble sort on v0 v1 v2
median1 = v1;

// sample locations (-1,+1) (0,+1) (+1,+1) into v0 v1 v2
// bubble sort on v0 v1 v2
median2 = v1;

// bubble sort on median0 median1 median2
return median1;
```

Even with that we’re flying pretty close to the register limit. Further savings can be made by reusing registers for multiple purposes. For instance median2 is only used after v0, v1 and v2 have stopped being used, so it can share a register with one of those. Similarly the variable used to locate texture sampling is never in use at the same time as the temp variable needed for bubble sorting, so they can share a register too.

# Implementation Details

I had planned to write a section on manually converting this code into AGAL, but I think this article is long enough already. It’s a fairly robotic process anyway: just make sure you keep good notes, and keep reminding yourself to switch to Haxe as soon as you can.

Speaking of Haxe and HxSL, something I learned writing this shader is that …

```haxe
median0 = v1;
```

… is not the same as …

```haxe
median0 = v1 + [0,0,0];
```

If you use the first, median0 seems to change when you later alter v1. Presumably this is part of an effort to save register space (after all, why keep two registers that have identical contents?), but the behaviour is a little unexpected.

# Results

My motivation for writing a median filter shader was to produce some simple realtime painterly effects. It turns out that my 3×3 sample area is not enough to produce anything particularly noticeable. That shouldn’t be surprising, as the strength of the median filter in image processing is that it preserves the image while removing noise.

However, after some simple modifications I’ve made two shaders that sample 1×9 and 9×1 areas. When both are applied in turn, each output pixel is influenced by a 9×9 square area, and the result approximates the median of that area rather than matching it exactly. This combination is of limited use in noise removal, but it does produce some quite pleasing results, all running very comfortably at 60fps.

# The full fragment shader for the 3×3 median filter, in HxSL:

```haxe
function fragment(texture:Texture, sizeOfTexel:Float2) {
    // due to limited availability of temporary variables
    // we cannot hold samples of all pixels in a 3x3 area.
    // instead we'll sample the first row, find its median,
    // sample the second row, find its median,
    // sample the third row, find its median,
    // then find the median of those three medians.
    // tuv is the interpolated texture coordinate, assumed to be
    // declared as a varying by the vertex shader (not shown).
    var median0:Float3;
    var median1:Float3;

    // fill v[n] with sampled values from the first row (y - 1)
    var localUV:Float2;
    localUV.y = tuv.y - sizeOfTexel.y;
    localUV.x = tuv.x - sizeOfTexel.x;
    var v0:Float4 = get(texture, localUV, clamp, nearest);
    localUV.x = tuv.x + 0;
    var v1:Float4 = get(texture, localUV, clamp, nearest);
    localUV.x = tuv.x + sizeOfTexel.x;
    var v2:Float4 = get(texture, localUV, clamp, nearest);

    var t:Float3;
    t = v0.xyz + [0, 0, 0];
    v0.xyz = min(t, v1.xyz);
    v1.xyz = max(t, v1.xyz);
    t = v1.xyz + [0, 0, 0];
    v1.xyz = min(t, v2.xyz);
    v2.xyz = max(t, v2.xyz);
    // after first iteration of bubble sort v2 is certain to be the maximum value
    // but median could be either v0 or v1
    t = v0.xyz + [0, 0, 0];
    v0.xyz = min(t, v1.xyz);
    v1.xyz = max(t, v1.xyz);
    // after a second iteration of bubble sort
    // v1 is also certain to have correct value
    // so v1 contains the median for this set of three values.
    // incidentally, v0 also has the correct minimum value.
    median0 = v1.xyz + [0, 0, 0];

    // -- now do all of that again to find median of second row (y + 0) --
    localUV.y = tuv.y + 0;
    localUV.x = tuv.x - sizeOfTexel.x;
    v0 = get(texture, localUV, clamp, nearest);
    localUV.x = tuv.x + 0;
    v1 = get(texture, localUV, clamp, nearest);
    localUV.x = tuv.x + sizeOfTexel.x;
    v2 = get(texture, localUV, clamp, nearest);
    t = v0.xyz + [0, 0, 0];
    v0.xyz = min(t, v1.xyz);
    v1.xyz = max(t, v1.xyz);
    t = v1.xyz + [0, 0, 0];
    v1.xyz = min(t, v2.xyz);
    v2.xyz = max(t, v2.xyz);
    t = v0.xyz + [0, 0, 0];
    v0.xyz = min(t, v1.xyz);
    v1.xyz = max(t, v1.xyz);
    median1 = v1.xyz + [0, 0, 0];

    // -- and again for third row (y + 1) --
    localUV.y = tuv.y + sizeOfTexel.y;
    localUV.x = tuv.x - sizeOfTexel.x;
    v0 = get(texture, localUV, clamp, nearest);
    localUV.x = tuv.x + 0;
    v1 = get(texture, localUV, clamp, nearest);
    localUV.x = tuv.x + sizeOfTexel.x;
    v2 = get(texture, localUV, clamp, nearest);
    t = v0.xyz + [0, 0, 0];
    v0.xyz = min(t, v1.xyz);
    v1.xyz = max(t, v1.xyz);
    t = v1.xyz + [0, 0, 0];
    v1.xyz = min(t, v2.xyz);
    v2.xyz = max(t, v2.xyz);
    t = v0.xyz + [0, 0, 0];
    v0.xyz = min(t, v1.xyz);
    v1.xyz = max(t, v1.xyz);
    //median2 = v1.xyz + [0, 0, 0];
    // just directly use v1 instead of median2

    // we now have a median for each row
    // bubble sort on them to find the median of medians
    t = median0.xyz + [0, 0, 0];
    median0 = min(t, median1.xyz);
    median1 = max(t, median1.xyz);
    t = median1.xyz + [0, 0, 0];
    median1 = min(t, v1.xyz);
    v2.xyz = max(t, v1.xyz);
    t = median0.xyz + [0, 0, 0];
    median0 = min(t, median1);
    median1 = max(t, median1);

    out = [median1.x, median1.y, median1.z, 1];
}
```