jax vmap and pmap are so awesome! consider this example i worked through today...

i have a dataset of 21,600 imgs across 4 hosts; so 5,400 imgs per host

i.e. each host holds an array of shape (5400, 64, 64, 3) 1/7
for each image i want to add 7 augmented versions, so i'll have 8 in total, corresponding to the 4 rotations (multiples of 90deg) × 2 flip options (L/R flipped or not). [image: example of the 8 rotation/flip combos for one image] how should i do this? a double nested loop? no! no! no! 2/7
first we start with an `augment` function that, given an img, a rotation amount (0, 1, 2 or 3 quarter turns) and whether to flip (0 or 1), returns a new image. 3/7
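a minimal sketch of what `augment` could look like (the thread's actual code isn't in the text, so this body is an assumption). one gotcha: `jnp.rot90` needs a concrete `k`, so once `rot` becomes a vmapped/traced value you can't pass it straight in; here the rotation goes through `jax.lax.switch`, which picks one of 4 fixed branches from a traced integer. it also assumes square images, so every branch returns the same shape.

```python
import jax
import jax.numpy as jnp

def augment(img, rot, flip):
    """rotate img (H, W, C) by rot * 90 degrees, then L/R flip if flip == 1.

    hypothetical sketch: rot in {0, 1, 2, 3}, flip in {0, 1}, square images assumed.
    """
    # one branch per quarter turn; k=k binds each k at definition time
    rotations = [lambda x, k=k: jnp.rot90(x, k=k, axes=(0, 1)) for k in range(4)]
    # lax.switch accepts a traced index, so this still works under vmap
    img = jax.lax.switch(rot, rotations, img)
    # choose between the flipped and unflipped image
    return jnp.where(flip == 1, jnp.fliplr(img), img)
```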
we want to vectorise this so it runs all 8 combos at once; to do this we can use vmap! called with a single image, the vmapped function returns a batch of 8 images, one per combo (note rot=0 & flip=0 reproduces the original image). should we jit this too? not yet, bit more to do. 4/7
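a sketch of that vmap step, assuming a hypothetical `augment(img, rot, flip)` like the one described in 3/7 (re-defined here so the block runs standalone). the single image is broadcast with `in_axes=None`, while the 8 (rot, flip) pairs are mapped over with `in_axes=0`. `ROTS`, `FLIPS` and `augment_all` are illustrative names, not from the thread.

```python
import jax
import jax.numpy as jnp

def augment(img, rot, flip):
    # hypothetical augment: rot quarter turns (traced-safe via lax.switch), then optional L/R flip
    rotations = [lambda x, k=k: jnp.rot90(x, k=k, axes=(0, 1)) for k in range(4)]
    img = jax.lax.switch(rot, rotations, img)
    return jnp.where(flip == 1, jnp.fliplr(img), img)

# the 8 (rot, flip) combos; index 0 is (0, 0), i.e. the original image
ROTS = jnp.array([0, 1, 2, 3, 0, 1, 2, 3])
FLIPS = jnp.array([0, 0, 0, 0, 1, 1, 1, 1])

def augment_all(img):
    # (H, W, C) -> (8, H, W, C): same img for every combo, map over ROTS/FLIPS
    return jax.vmap(augment, in_axes=(None, 0, 0))(img, ROTS, FLIPS)
```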
next we want to be able to run this on a batch of images, not just one. how do we do this? add another vmap! and to flatten things a bit we include a reshape too, so the (N, 8, 64, 64, 3) output becomes (N*8, 64, 64, 3). should we jit this? not quite... 5/7
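a sketch of the batched version under the same assumptions (hypothetical `augment`, `ROTS`, `FLIPS`; `augment_batch` is an illustrative name). the outer vmap maps over the images and broadcasts the combo arrays; the reshape keeps each image's 8 versions adjacent, so rows 8i..8i+7 all come from image i.

```python
import jax
import jax.numpy as jnp

def augment(img, rot, flip):
    # hypothetical augment: rot quarter turns (traced-safe via lax.switch), then optional L/R flip
    rotations = [lambda x, k=k: jnp.rot90(x, k=k, axes=(0, 1)) for k in range(4)]
    img = jax.lax.switch(rot, rotations, img)
    return jnp.where(flip == 1, jnp.fliplr(img), img)

ROTS = jnp.array([0, 1, 2, 3, 0, 1, 2, 3])
FLIPS = jnp.array([0, 0, 0, 0, 1, 1, 1, 1])

def augment_batch(imgs):
    # inner vmap: one image -> 8 combos; outer vmap: over the N images
    per_image = jax.vmap(augment, in_axes=(None, 0, 0))
    out = jax.vmap(per_image, in_axes=(0, None, None))(imgs, ROTS, FLIPS)  # (N, 8, H, W, C)
    # flatten the image and combo axes: (N * 8, H, W, C)
    return out.reshape(-1, *imgs.shape[1:])
```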
turns out the images are sharded over the 8 cores of a TPU, so the shape is actually (8, 675, 64, 64, 3), not (5400, 64, 64, 3)

so we don't want to call it directly, but via a pmap so it runs on all cores in parallel 6/7
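a sketch of the pmap step, again with hypothetical names (`augment`, `ROTS`, `FLIPS`, `augment_batch`, `augment_sharded`). pmap maps over the leading device axis, so each core receives its own (675, 64, 64, 3) slice and returns (5400, 64, 64, 3); on the 8-core layout above the full result would be (8, 5400, 64, 64, 3). using `jax.local_device_count()` keeps the sketch runnable on a single-device machine too.

```python
import jax
import jax.numpy as jnp

def augment(img, rot, flip):
    # hypothetical augment: rot quarter turns (traced-safe via lax.switch), then optional L/R flip
    rotations = [lambda x, k=k: jnp.rot90(x, k=k, axes=(0, 1)) for k in range(4)]
    img = jax.lax.switch(rot, rotations, img)
    return jnp.where(flip == 1, jnp.fliplr(img), img)

ROTS = jnp.array([0, 1, 2, 3, 0, 1, 2, 3])
FLIPS = jnp.array([0, 0, 0, 0, 1, 1, 1, 1])

def augment_batch(imgs):
    # per-device work: (per_device, H, W, C) -> (per_device * 8, H, W, C)
    per_image = jax.vmap(augment, in_axes=(None, 0, 0))
    out = jax.vmap(per_image, in_axes=(0, None, None))(imgs, ROTS, FLIPS)
    return out.reshape(-1, *imgs.shape[1:])

# pmap(vmap(vmap(augment))): leading axis must equal the number of local devices;
# pmap also jit-compiles the whole chain with XLA
augment_sharded = jax.pmap(augment_batch)
```

usage, e.g. `augment_sharded(jnp.zeros((jax.local_device_count(), 675, 64, 64, 3)))`, gives shape `(num_devices, 5400, 64, 64, 3)`.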
note the little bit of bookkeeping needed for the labels too: each label has to be repeated 8 times to line up with its 8 image combos.
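a sketch of the label bookkeeping, assuming the images were flattened so each image's 8 combos are adjacent (as in a (N, 8, ...) -> (N*8, ...) reshape). `jnp.repeat` matches that order (`jnp.tile` would not). `repeat_labels` and `repeat_labels_sharded` are hypothetical names.

```python
import jax
import jax.numpy as jnp

NUM_COMBOS = 8  # 4 rotations x 2 flips

def repeat_labels(labels):
    # (N,) -> (N * 8,): [a, b] -> [a]*8 + [b]*8, aligned with the flattened image batch
    return jnp.repeat(labels, NUM_COMBOS, axis=0)

# sharded labels: (num_devices, per_device) -> (num_devices, per_device * 8)
repeat_labels_sharded = jax.pmap(repeat_labels)
```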

pmap(vmap(vmap(augment))) for the win! yeah yeah! 7/7
oh! forgot to mention the jit! that final pmap jit-compiles the entire call chain end to end with XLA, so no separate jit is needed, and you get crazy fast accelerated code.
You can follow @mat_kelcey.