Tag Archives: automatic differentiation

Engineering Trade-Offs in Automatic Differentiation: from TensorFlow and PyTorch to Jax and Julia

By: Christopher Rackauckas

Re-posted from: http://www.stochasticlifestyle.com/engineering-trade-offs-in-automatic-differentiation-from-tensorflow-and-pytorch-to-jax-and-julia/

To understand the differences between automatic differentiation libraries, let’s talk about the engineering trade-offs that were made. I would personally say that none of these libraries is “better” than another; they simply all make engineering trade-offs based on the domains and use cases they were aiming to satisfy. The easiest way to describe these trade-offs is to follow the evolution and see how each new library tweaked the trade-offs made by the previous one.

Early TensorFlow used a graph building system, i.e. it required users to essentially define variables in a specific graph language separate from the host language. You had to define “TensorFlow variables” and “TensorFlow ops”, and the AD would then be performed on this static graph. Control flow constructs were limited to the constructs that could be represented statically. For example, an `ifelse` function call is very different from a conditional `if` then `else` in code, because `ifelse` is semantically the same as always calling both branches and then choosing the result, and thus has only a single code path (though I say semantically because further compiler optimizations may, and usually do, reduce that). This static sublanguage was then represented in an intermediate representation (IR) known as XLA, which performed a lot of simplification of linear algebra, and AD was done using simple graph algorithms, given that there was no true control flow at this representation. While this gives a lot of efficiency (XLA is great for simplification because it can easily see the whole world), it of course had some major downsides in terms of flexibility and convenience.

Thus you can almost think of this as a source code transformation, because all of the autodiff is done on essentially an IR for a language which is not the same as the host language. But for the most part it required the user to do the translation to the new language for the AD system, which is… rather inconvenient.
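To make the `ifelse` versus `if`/`else` distinction concrete, here is a minimal sketch in plain Python (not TensorFlow code). The names `ifelse`, `quadratic`, `linear`, `static_style` and `dynamic_style` are hypothetical, chosen only for illustration. It records which branch bodies actually ran in each style.

```python
calls = []  # records which branch bodies actually ran


def quadratic(x):
    calls.append("quadratic")
    return 3.0 * x ** 2


def linear(x):
    calls.append("linear")
    return -4.0 * x


def ifelse(cond, a, b):
    # Both a and b were already computed by the caller; this only selects.
    return a if cond else b


def static_style(x):
    # Single code path: both branches are evaluated, then one is chosen.
    return ifelse(x < 3, quadratic(x), linear(x))


def dynamic_style(x):
    # True control flow: only the taken branch is evaluated.
    if x < 3:
        return quadratic(x)
    return linear(x)


static_result = static_style(2.0)
static_calls = list(calls)

calls.clear()
dynamic_result = dynamic_style(2.0)
dynamic_calls = list(calls)

print(static_result, static_calls)
print(dynamic_result, dynamic_calls)
```

Both styles return the same value, but the static style does the work of both branches. A real graph compiler may later remove the unused work, which is the "semantically" caveat above.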

PyTorch came along to solve the flexibility and convenience issues by instead using a tape-based method. It generates the code to autodiff every time you run the forward pass by simply storing the operations that it sees in a given forward pass, and then differentiates that set of operations in reverse. This “building of the tape” is done by operator overloading as part of the Tensor type PyTorch says you need to use. How it works is easy to see. Suppose a function f contains an if statement followed by a while loop, and that f(2.0) takes the first branch of the if statement and then runs the while loop 5 times. The AD pass would then take that set of operations and run backpropagation through 5 passes of the while loop and back through the first branch. Notice that by using this form, the AD does not “see” any dynamic control flow: that was all in Python, but not in the tape. Thus the AD does not have to handle dynamic control flow, and this makes it very easy to handle a lot of odd cases of the language. The downside to this approach, though, is that the AD is “per value”: you will not necessarily ever see the same backwards pass again, so there is much less room to optimize the backwards passes.
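The tape idea can be sketched in a few lines of plain Python. This is a toy illustration under my own assumptions, not PyTorch's implementation: `Var`, `backward`, `gradient` and the function `f` are all hypothetical names, and `f` is built so that `f(2.0)` takes the first branch and runs the while loop 5 times.

```python
class Var:
    """A scalar that records every operation applied to it on a shared tape."""

    def __init__(self, value, tape, parents=()):
        self.value = value
        self.tape = tape
        self.parents = parents  # pairs of (input Var, local partial derivative)
        self.grad = 0.0
        tape.append(self)

    def _lift(self, other):
        # Constants become leaf nodes on the same tape.
        return other if isinstance(other, Var) else Var(float(other), self.tape)

    def __add__(self, other):
        other = self._lift(other)
        return Var(self.value + other.value, self.tape,
                   ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        other = self._lift(other)
        return Var(self.value * other.value, self.tape,
                   ((self, other.value), (other, self.value)))

    __radd__ = __add__
    __rmul__ = __mul__


def backward(output):
    # Walk the recorded operations in reverse and apply the chain rule.
    output.grad = 1.0
    for node in reversed(output.tape):
        for parent, local in node.parents:
            parent.grad += local * node.grad


def f(x):
    # The control flow reads x.value in Python, so it never appears on the tape.
    if x.value > 1.0:
        y = x * x
    else:
        y = x + 1.0
    while y.value < 100.0:
        y = y * 2.0
    return y


def gradient(fn, x0):
    tape = []
    x = Var(x0, tape)
    y = fn(x)
    backward(y)
    ops = sum(1 for node in tape if node.parents)  # recorded operations
    return y.value, x.grad, ops


value_a, grad_a, ops_a = gradient(f, 2.0)  # first branch, 5 loop iterations
value_b, grad_b, ops_b = gradient(f, 0.5)  # second branch, 7 loop iterations
print(value_a, grad_a, ops_a)
print(value_b, grad_b, ops_b)
```

The two inputs produce tapes of different lengths, which is the "per value" property: each call differentiates only the straight-line list of operations that happened to run.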

Does this harm PyTorch’s efficiency beyond repair? Well, no and yes. No, in the sense that most machine learning algorithms are so heavily reliant on expensive kernels, such as matrix multiplication (`A*x`), `conv`, etc., that the amount of work per operation is high enough to hide the overhead of this approach. This allows the PyTorch team to spend most of its time optimizing the 2,000+ operators that it provides, and so most people in ML see PyTorch as fast because it comes with fast kernels (fast conv calls, fast GPU linear algebra) despite the AD overhead. That said, you can very easily run into cases where AD and Python interpreter overhead are not washed out. These are cases where your arrays are small or where a lot of scalar operations are happening. For example, in the Julia vs PyTorch Neural ODE benchmarks on cases matching scientific model discovery workflows, you see a 100x performance improvement in Julia (with major differences even without AD in the ODE and SDE solvers), which can mostly be attributed to language and AD overhead due to the small kernels used in these cases. For this reason the PyTorch team has been working on things like `torch.jit` (TorchScript) as a separate sublanguage that can compile and optimize differently from the rest of the code, specific to handling these cases, though there’s a lot of discussion of the long-term viability of that approach. But anyways, PyTorch has done really well because it made good choices for its domain of use.

So then TensorFlow Eager (2.0) comes around and adds dynamic control flow support in a manner similar to PyTorch as a sad attempt to get everyone back, but of course then it doesn’t play nicely with all of the XLA tooling (because it cannot see the whole graph of all possible operations for all input values to optimize it well). So it didn’t hit the TensorFlow speeds everyone was expecting, and it was kind of the worst of all worlds.

Subsequent tools then all sought ways to either expand the domains of these ideas or try to mix some of the advantages of the two sides. Jax is one of those. Jax uses non-standard interpretation to build a copy of the full code in its own IR to then perform AD on, finally lowering it to TensorFlow’s XLA for optimizations. Jax’s non-standard interpretation is kind of like operator overloading in that it has special objects walk through code in order to build out the expressions (this is called the “tracing” step). But wait, how is it able to trace the full code if there’s dynamic control flow? Won’t it have the same issue as PyTorch, namely that it only sees parts of the full code’s potential paths? Indeed that is true, and that’s why it doesn’t want you to use full dynamic control flow and instead wants you to use Jax primitives like lax.while_loop, which are function calls that can be caught during tracing to avoid the code having true dynamic behavior at trace time. Also, for this to work, what your function does must be completely determined by its inputs, i.e. the functions must be “pure”. For this reason Jax requires programming in a functional style with pure functions rather than the object-oriented standard of Python, which is a notable trade-off of the abstract interpretation approach. But what you essentially get is a more natural graph builder for TensorFlow, because at the end of the day it ends up in TensorFlow’s XLA IR, and so you get the same efficiency there but in a form that can look and feel a lot more natural. The downside of course is that you still don’t have true dynamism, which is why those primitives exist, and why they are not well optimized, as described in Jax – The Sharp Bits. However, “most” ML algorithms don’t use very much dynamism (example: recurrent neural networks know how many layers they have, they don’t have a while loop iterate to tolerance), and so “most” algorithms tend to do well in this sublanguage. In that sense, it can optimize a lot of codes rather naturally.

What about keeping dynamism in the AD?

This of course then raises the question: is it possible to keep the full dynamism of the host language in the AD system? It is possible, but it is hard. This is what a lot of the Julia AD tools have focused on with source code transformations (along with Swift for TensorFlow). However, since source code is “for humans”, it can be a rather difficult level to work on algorithmically. Thus these tools instead work on lowered IR, where the lowered representations remove a lot of the “cruft” of syntax to give a much smaller support surface. This was the core of Zygote.jl’s approach, where it saw that by acting on the SSA IR it could directly support control flow like while loops without unrolling them into sets of operations (like PyTorch or TensorFlow Eager) or only supporting a sublanguage of control flow (like Jax). This is essentially done by converting while loops and other dynamic constructs into static (source code) representations that have new lines of code in there for things like stacks that keep information about the forward pass (like which branch is taken), and then these stacks are accessed and used in the generated backwards pass. Thus the code that is generated is not dynamic (a while loop forward gives a for loop in reverse), but the generated backwards pass behaves dynamically (because it uses the stack to tell it how many times to walk the for loop). This allows AD to have a single code for all branches (unlike the tape-building forms) and thus it can optimize more like TensorFlow but in a world where the dynamic control flow is not eliminated.
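As a rough picture of what such generated code looks like, here is a hand-written sketch in plain Python. It is not Zygote's actual output, and the names `forward`, `backward` and `limit` are hypothetical. The forward pass keeps its while loop and pushes what the reverse pass will need onto a stack; the reverse pass is a for loop over that stack.

```python
def forward(x, limit=100.0):
    # Assumes x > 1 so that repeated squaring eventually exceeds the limit.
    stack = []
    y = x
    while y < limit:
        stack.append(y)  # save the value the reverse pass will need
        y = y * y
    return y, stack


def backward(stack, ybar=1.0):
    # One fixed piece of code for every input: the stack length decides
    # how many times the loop body runs.
    for saved in reversed(stack):
        ybar = ybar * 2.0 * saved  # derivative of y*y with respect to y
    return ybar


y_a, stack_a = forward(2.0)   # 2 -> 4 -> 16 -> 256, three iterations
grad_a = backward(stack_a)    # derivative of x**8 at x = 2

y_b, stack_b = forward(20.0)  # 20 -> 400, one iteration
grad_b = backward(stack_b)    # derivative of x**2 at x = 20

print(y_a, grad_a, len(stack_a))
print(y_b, grad_b, len(stack_b))
```

Unlike a tape, the same `forward` and `backward` code serves every input, so a compiler gets to optimize it once.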

Well that sounds like the best of both worlds, so why isn’t everyone using it? There are two factors involved in that. First, accepting that your AD will have to deal with the full dynamic nature of an entire programming language means accepting a much more difficult job. The whole purpose of the AD approaches in TensorFlow/PyTorch/Jax is for these constructs to be eliminated before the AD, so they have a much smaller surface of language support required. This added complexity pretty much guarantees you cannot use Python, because it’s such a crazy language in terms of what it allows with dynamism (fun fact, the Jax folks at Google Brain did have a Python source code transform AD at one point but it was scrapped essentially because of these difficulties), and so people working on these solutions flocked to languages with clear syntax that is easy for compilers to optimize, i.e. Julia and Swift. Python has most of the ML crowd, so that creates a barrier to entry.

But even then, the problem is still very hard. In Julia it was found that Zygote acts on too high of an IR, i.e. before compiler optimizations, which then requires you to do AD on unoptimized code only to delete most of the work later, and so it would be better for it to go even lower. This is the reason why the Diffractor.jl project started. But there’s even a reason to act lower, since some optimization only occurs at the LLVM level, which is why Julia developers started directly building an AD system that acts on LLVM’s IR itself, known as Enzyme (note that while this project included members of the Julia Lab like Valentin, because it acts at the LLVM level it is applicable to any LLVM compiled language, such as C/C++ (Clang) or Rust). There is then a trade-off that occurs with source code transform methods as you go lower and lower in the IRs, which I describe in a separate post. tl;dr there: Enzyme can act after compiler optimizations, so much of the higher level information might be deleted (at least, without completion of dialects like MLIR which aren’t quite ready). Enzyme only sees the barest of low level code, so it may not have the high level linear algebra definitions to do all of the linear algebra simplifications, like how XLA will fuse many matrix-vector multiplications into a matrix-matrix multiplication, since some of the function calls may have been inlined and deleted. Optimizing this remaining loopy code to reach BLAS speeds is thus as hard as generating looping code that reaches BLAS speeds, and history shows this is hard but not impossible. Additionally, function calls to a nonlinear solver may have already been deleted, so optimized adjoints which outperform the direct differentiation of code, like in the case of Deep Equilibrium Models (DEQs), may end up less optimized. But that lowest level allows for very efficient scalar code differentiation and mutation support. On the other hand, Diffractor uses Julia’s typed IR so it can apply higher level rules easily and consistently, and in theory it can do transformations similar to XLA (i.e. keeping BLAS calls intact and fusing them). But writing such analyses on a fully dynamic compute IR is difficult enough that it has not been done. Tooling around escape analysis and shape propagation is being built to try and enable such optimizations, but the fact remains that it’s a lot more work to do it on a language IR instead of a sublanguage graph like XLA. In theory you could have compiler passes prove that a function is semi-static in the sense of XLA and get the same optimizations as Jax or TensorFlow, but that doesn’t happen today and it’s not easy to do. The future of Julia AD systems will likely mix the Enzyme and Diffractor approaches to tackle this issue, but the clear trade-off being made here is generality at the cost of implementation complexity.

The second factor, and probably the more damning one, is that most ML codes don’t actually use that much dynamism. Recurrent neural networks, transformers, convolutions, etc. all have simpler forms of dynamism which in some sense are quite static. That’s an important trade-off most people don’t always consider: why solve problems your users don’t have? The number of layers you have does not depend on the values coming out of the layers. Support for dynamism for ML workflows is thus mostly about convenience, not necessity. When algorithms do have dynamism, in most cases you can get away with wrapping it as an operation in the language, i.e. defining a function and defining the adjoint derivative for that function. This for example is how Jax supports ODEs even though adaptive ODE solvers require knowing the calculated values in order to determine the number of steps. You cannot differentiate an ODE solver’s code with Jax, but if you use an ODE solver with a defined adjoint you are okay. While this does mean that some algorithms are not possible with Jax (at least without forgoing a lot of optimizations), and algorithms where differentiating solvers is fundamentally different from adjoint definitions can limit which performance/stability trade-offs can be made (see the supplemental section 8 for details in the case of ODEs about stability of “discrete adjoints”), these factors seem to be rather rare in standard ML use cases, which is why most people haven’t bothered to learn a new programming language to get around these issues.
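Here is a small sketch of that "wrap it and define the adjoint" pattern in plain Python, using a Newton solve instead of an ODE solver to keep it short. The names `newton_sqrt`, `newton_sqrt_adjoint` and `rtol` are hypothetical, and the derivative rule is the standard implicit function theorem result for y*y = a, not any framework's API. It assumes a > 0.

```python
import math


def newton_sqrt(a, rtol=1e-12):
    """Solve y*y = a for y > 0. The number of steps depends on a and rtol."""
    y = max(a, 1.0)
    steps = 0
    while abs(y * y - a) > rtol * a:
        y = y - (y * y - a) / (2.0 * y)
        steps += 1
    return y, steps


def newton_sqrt_adjoint(a, ybar=1.0):
    # Hand-defined adjoint: differentiate the equation y*y = a, not the loop.
    # Implicit differentiation gives 2*y*dy = da, so dy/da = 1 / (2*y).
    y, _ = newton_sqrt(a)
    return ybar / (2.0 * y)


y_small, steps_small = newton_sqrt(2.0)
y_large, steps_large = newton_sqrt(1.0e6)
grad_small = newton_sqrt_adjoint(2.0)

print(y_small, steps_small, grad_small)
print(y_large, steps_large)
```

From the outside the solver looks like a fixed node with a fixed derivative rule, even though the loop inside takes a different number of steps for different inputs.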

That leaves us where we are today. Are more ML algorithms of the future going to require handling more dynamic structures? Is optimizing scalar and mutating code going to be important for people using AD systems? The reason why I know this story so well is because the answer for my domain, scientific machine learning (SciML), is yes. Climate models use mutation because reallocating huge buffers would greatly affect performance. Adaptive solvers on stiff equations are a fact of life, so simple adjoints used in PyTorch and Jax are unstable and simply give Inf as the gradients in these cases. Time will tell whether this physics-informed, expert-guided, science-guided, scientific machine learning domain becomes standard, but hopefully this describes how all of the choices made here were not “better” or “worse”, but instead it’s all about domain-specific engineering trade-offs.

The post Engineering Trade-Offs in Automatic Differentiation: from TensorFlow and PyTorch to Jax and Julia appeared first on Stochastic Lifestyle.

Useful Algorithms That Are Not Optimized By Jax, PyTorch, or Tensorflow

By: Christopher Rackauckas

Re-posted from: http://www.stochasticlifestyle.com/useful-algorithms-that-are-not-optimized-by-jax-pytorch-or-tensorflow/

In some previous blog posts we described in detail how one can generalize automatic differentiation to automatically give stability enhancements and all sorts of other niceties by incorporating graph transformations into code generation. However, one of the things which we didn’t go into too much is the limitation of these types of algorithms. This limitation is what we have termed “quasi-static”, which is the property that an algorithm can be reinterpreted as some static algorithm. It turns out that for very fundamental reasons, this is the same limitation that some major machine learning frameworks, such as Jax or Tensorflow, impose on the code that they can fully optimize. This led us to the question: are there algorithms which are not optimizable within this mindset, and why? The answer is now published at ICML 2021, so let’s dig into this higher level concept.

The Space of Quasi-Static Algorithms

First of all, let’s come up with a concrete idea of what a quasi-static algorithm is. It’s the space of algorithms which in some way can be re-expressed as a static algorithm. Think of a “static algorithm” as one which has a simple mathematical description that does not require a full computer description, i.e. no loops, rewriting to memory, etc. As an example, let’s take a look at an example from the Jax documentation. The following is something that the Jax JIT works on:

from jax import jit

@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x

print(f(3))

Notice that it’s represented by something with control flow, i.e. it is code represented with a loop, but the loop is not necessary. We can also understand this method as 2*2*2*x or 8*x. The demonstrated example of where the JIT will fail by default is:

@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x
 
# This will fail!
try:
  f(2)
except Exception as e:
  print("Exception {}".format(e))

In this case, we can see that there are essentially two compute graphs split at x<3, and so as stated this does not have a single mathematical statement that describes the computation. You can get around this by doing lax.cond(x < 3, lambda x: 3. * x ** 2, lambda x: -4 * x, x), but notice this is a fundamentally different computation: the lax.cond form puts both sides of the if statement into the compute graph ahead of time and then selects which result to carry forward (semantically the same as computing both sides and choosing one), while the true if statement changes its computation based on the conditional. The reason why the lax.cond form works with Jax's JIT compilation system is that it is quasi-static. The computations that will occur are fixed, even if the result is not, while the original if statement will change what is computed based on the input values. This limitation exists because Jax traces through a program to attempt to build the static compute graph under the hood, and it then attempts to do its actual transformations on this graph. Are there other kinds of frameworks that do something similar? It turns out that the set of algorithms which are transformable into purely symbolic languages is also the set of quasi-static algorithms, so something like Symbolics.jl also has a form of quasi-staticness manifest in the behaviors of its algorithms. And it’s for the same reason: in symbolic algorithms you define symbolic variables like “x” and “y”, and then trace through a program to build a static compute graph for “2x^2 + 3y” which you then treat symbolically. In the frequently asked questions, there is a question about what happens when a conversion of a function to symbolic form fails. If you take a look at the example:

function factorial(x)
  out = x
  while x > 1
    x -= 1
    out *= x
  end
  out
end
 
@variables x
factorial(x)

You can see that the reason for this is that the algorithm is not representable as a single mathematical expression: the factorial cannot be written as a fixed number of multiplications because the number of multiplications depends on the value x for which you’re trying to compute x!. The error that the symbolic language throws is “ERROR: TypeError: non-boolean (Num) used in boolean context”, which is saying that it does not know how to symbolically expand out “while x > 1” to be able to represent it statically. And this is not something that is necessarily “fixable”: it’s fundamental to the fact that this algorithm is not able to be represented by a fixed computation and necessarily needs to change the computation based on the input.
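The same behavior can be reproduced with a toy tracer in plain Python. This is only a sketch of the idea, not how Jax or Symbolics.jl are implemented: `Sym`, `trace`, `repeated_double` and this Python `factorial` are hypothetical names. A loop over a fixed `range` unrolls into one expression, while a loop whose condition depends on the symbolic value cannot be traced.

```python
def _expr(value):
    return value.expr if isinstance(value, Sym) else repr(value)


class Sym:
    """A symbolic value that records the expression built from it."""

    def __init__(self, expr):
        self.expr = expr

    def __mul__(self, other):
        return Sym(f"({self.expr} * {_expr(other)})")

    def __rmul__(self, other):
        return Sym(f"({_expr(other)} * {self.expr})")

    def __sub__(self, other):
        return Sym(f"({self.expr} - {_expr(other)})")

    def __gt__(self, other):
        return Sym(f"({self.expr} > {_expr(other)})")

    def __bool__(self):
        # A symbolic comparison has no concrete truth value at trace time.
        raise TypeError("symbolic value used in boolean context")


def trace(fn):
    return fn(Sym("x")).expr


def repeated_double(x):
    for _ in range(3):  # fixed trip count: unrolls during tracing
        x = 2 * x
    return x


def factorial(x):
    out = x
    while x > 1:  # trip count depends on the value of x
        x -= 1
        out *= x
    return out


static_graph = trace(repeated_double)
try:
    trace(factorial)
    factorial_traced = True
except TypeError as err:
    factorial_traced = False
    failure_message = str(err)

print(static_graph)
print(factorial_traced)
```

The first function yields a single static expression. The second fails as soon as the tracer is asked to decide `x > 1`, which is the same kind of failure as the errors quoted above.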

Handling Non-Quasi-Static Algorithms in Symbolics and Machine Learning

The “solution” is to define a new primitive to the graph via “@register factorial(x)”, so that this function itself is a fixed node that does not try to be symbolically expanded. This is the same concept as defining a Jax primitive or a Tensorflow primitive, where an algorithm simply is not quasi-static and so the way to get a quasi-static compute graph is to treat the dynamic block as just a function “y = f(x)” that is preordained. And in the context of both symbolic languages and machine learning frameworks, for this to work in full you also need to define derivatives of said function. That last part is the catch. If you take another look at the depths of the documentation of some of these tools, you’ll notice that many of these primitives representing non-static control flow fall outside of the realm that is fully handled.

Right there in the documentation it notes that you can replace a while loop with lax.while_loop, but that this is not amenable to reverse-mode automatic differentiation. The reason is that its reverse-mode AD implementation assumes that such a quasi-static algorithm exists and uses this for two purposes: first for generating the backwards pass, and second for generating the XLA (“Tensorflow”) description of the algorithm to then JIT compile and optimize. XLA wants the static compute graph, which again does not necessarily exist for this case, hence the fundamental limitation. The way to get around this of course is then to define your own primitive with its own fast gradient calculation and this problem goes away…

Or does it?

Where Can We Find The Limit Of Quasi-Static Optimizers?

There are machine learning frameworks which do not make the assumption of quasi-staticness but still optimize, and most of these are things like Diffractor.jl, Zygote.jl, and Enzyme.jl in the Julia programming language (note PyTorch does not assume quasi-static representations, though TorchScript’s JIT compilation does). This got me thinking: are there actual machine learning algorithms for which this is a real limitation? This is a good question, because if you pull up your standard methods like convolutional neural networks, that’s a fixed function kernel call with a good derivative defined, or a recurrent neural network, that’s a fixed size for loop. If you want to break this assumption, you have to go to a space that is fundamentally about an algorithm where you cannot know “the amount of computation” until you know the specific values in the problem, and equation solvers are something of this form.

How many steps does it take for Newton’s method to converge? How many steps does an adaptive ODE solver take? These are not questions that can be answered a priori: they are fundamentally questions which require knowing:

  1. what equation are we solving?
  2. what is the initial condition?
  3. over what time span?
  4. with what solver tolerance?

For this reason, people who work in Python frameworks have been looking for the “right” way to treat equation solving (ODE solving, finding roots f(x)=0, etc.) as a blackbox representation. If you take another look at the Neural Ordinary Differential Equations paper, one of the big things it was proposing was the treatment of neural ODEs as a blackbox with a derivative defined by the ODE adjoint. The reason of course is that adaptive ODE solvers necessarily iterate to tolerance, so there is necessarily something like “while t < tend” which is dependent on whether the current computations are computed to tolerance. Since this is not optimized in the frameworks they were working in, the blackbox treatment was required to make the algorithm work.

Should You Treat Equation Solvers As a Quasi-Static Blackbox?

No, it’s not fundamental to have to treat such algorithms as a blackbox. In fact, we had a rather popular paper a few years ago showing that neural stochastic differential equations can be trained with forward and reverse mode automatic differentiation directly via some Julia AD tools. The reason is that these AD tools (Zygote, Diffractor, Enzyme, etc.) do not necessarily assume quasi-static forms due to how they do direct source-to-source transformations, and so they can differentiate the adaptive solvers directly and spit out the correct gradients. So you do not necessarily have to do it in the “define a Tensorflow op” style, but which is better?

It turns out that “better” can be really hard to define because the two algorithms are not necessarily the same and can compute different values. You can boil this down to: do you want to differentiate the solver of the equation, or do you want to differentiate the equation and apply a solver to that? The former, which is equivalent to automatic differentiation of the algorithm, is known as discrete sensitivity analysis or discretize-then-optimize. The latter is known as continuous sensitivity analysis or optimize-then-discretize. Machine learning is not the first field to come up against this problem, so the paper on universal differential equations and the scientific machine learning ecosystem has a rather long description that I will quote:

“””
Previous research has shown that the discrete adjoint approach is more stable than continuous adjoints in some cases [41, 37, 42, 43, 44, 45] while continuous adjoints have been demonstrated to be more stable in others [46, 43] and can reduce spurious oscillations [47, 48, 49]. This trade-off between discrete and continuous adjoint approaches has been demonstrated on some equations as a trade-off between stability and computational efficiency [50, 51, 52, 53, 54, 55, 56, 57, 58]. Care has to be taken as the stability of an adjoint approach can be dependent on the chosen discretization method [59, 60, 61, 62, 63], and our software contribution helps researchers switch between all of these optimization approaches in combination with hundreds of differential equation solver methods with a single line of code change.
“””

Or, tl;dr: there’s tons of prior research which generally shows that continuous adjoints are less stable than discrete adjoints, but they can be faster. We have done recent follow-ups which show these claims are true on modern problems with modern software. Specifically, this paper on stiff neural ODEs shows why discrete adjoints are more stable than continuous adjoints when training on multiscale data, but we also recently showed continuous adjoints can be much faster at gradient computations than (some) current AD techniques for discrete adjoints.

So okay, there’s a true benefit to using discrete adjoint techniques if you’re handling these hard stiff differential equations, differentiating partial differential equations, etc., and this has been known since the 80’s in the field of control theory. But other than that, it’s a wash, and so it’s not clear whether differentiating such algorithms is better in machine learning, right?

Honing In On An Essentially Non-Quasi-Static Algorithm Which Accelerates Machine Learning

This now brings us to how the recent ICML paper fits into this narrative. Is there a non-quasi-static algorithm that is truly useful for standard machine learning? The answer turns out to be yes, but getting there requires a few slick tricks. First, the setup. Neural ODEs can be an interesting method for machine learning because they use an adaptive ODE solver to essentially choose the number of layers for you, so it’s like a recurrent neural network (or more specifically, like a residual neural network) that automatically finds the “correct” number of layers, where the number of layers is the number of steps the ODE solver decides to take. In other words, Neural ODEs for image processing are an algorithm that automatically does hyperparameter optimization. Neat!

But… what is the “correct” number of layers? For hyperparameter optimization you’d assume that would be “the least number of layers to make predictions accurately”. However, by default neural ODEs will not give you that number of layers: they will give you whatever they feel like. In fact, if you look at the original neural ODE paper, as the neural ODE trains it keeps increasing the number of layers it uses:

So is there a way to change the neural ODE to make it define “correct number of layers” as “least number of layers”? In the work Learning Differential Equations that are Easy to Solve they did just that. They did it by regularizing the training process of the neural ODE. They looked at the solution and noted that ODEs which have more changes going on are necessarily harder to solve, so you can transform the training process into hyperparameter optimization by adding a regularization term that says “make the higher order derivative terms as small as possible”. The rest of the paper is how to enact this idea. How was that done? Well, if you have to treat the algorithm as a blackbox, you need to define some blackbox way of defining high order derivatives, which then leads to Jesse’s pretty cool formulation of Taylor-mode automatic differentiation. But no matter how you put it, that’s going to be an expensive object to compute: computing the gradient is more expensive than the forward pass, the second derivative more so than the gradient, the third more so again, etc., so an algorithm that wants 6th derivatives is going to be nasty to train. With some pretty heroic work they got a formulation of this blackbox operation which takes twice as long to train but successfully does the hyperparameter optimization.
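In symbols, the idea of that regularizer can be sketched roughly as follows. The notation is illustrative shorthand and not taken from the paper: z(t) is the neural ODE state, f_θ the network defining the dynamics, L the task loss, K the derivative order being penalized, and λ a hypothetical regularization weight.

```latex
\frac{dz}{dt} = f_\theta\big(z(t), t\big), \qquad
\min_\theta \; L\big(z(t_1)\big)
  + \lambda \int_{t_0}^{t_1} \left\lVert \frac{d^K z(t)}{dt^K} \right\rVert_2^2 \, dt
```

The cost concern in the paragraph above comes from the integrand: evaluating the K-th total derivative of the solution along the trajectory is what requires the higher order (Taylor-mode) differentiation.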

End of story? Far from it!

The Better Way to Make Neural ODEs An Automatic Hyperparameter Optimizing Algorithm

Is there a way to make automatic hyperparameter optimization via neural ODEs train faster? Yes, and our paper not only makes them train faster than that other method, but makes them train faster than the vanilla neural ODE. We can make layer hyperparameter optimization less than free: we can make it cheaper than not doing the optimization! But how? The trick is to open the blackbox. Let me show you what a step of the adaptive ODE solver looks like:

Notice that the adaptive ODE solver chooses whether a time step is appropriate by using an error estimate. The ODE algorithm is actually constructed so that the error estimate, the estimate of “how hard this ODE is to solve”, is computed for free. What if we use this free error estimate as our regularization technique? It turns out that this is 10x faster to train than before, while similarly automatically performing hyperparameter optimization.
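To give a feel for where that free error estimate lives, here is a minimal adaptive solver sketch in plain Python. It is not the solver or the regularizer from the paper: it uses a simple embedded Euler/Heun pair, and the names `solve_adaptive`, `tol` and `reg` as well as the step size controller constants are assumptions made for illustration.

```python
import math


def solve_adaptive(f, u0, t0, t1, tol=1e-4, dt=0.1):
    """Scalar ODE u' = f(u, t) with an embedded Euler/Heun pair."""
    t, u = t0, u0
    steps = 0
    reg = 0.0  # running sum of accepted error estimates (the "free" quantity)
    while t < t1:
        last = t + dt >= t1
        if last:
            dt = t1 - t
        k1 = f(u, t)
        k2 = f(u + dt * k1, t + dt)
        u_low = u + dt * k1                # Euler, first order
        u_high = u + 0.5 * dt * (k1 + k2)  # Heun, second order, reuses k1
        err = abs(u_high - u_low)          # error estimate, no extra f calls
        if err <= tol:
            t = t1 if last else t + dt
            u = u_high
            steps += 1
            reg += err
        # Grow or shrink the step based on the same estimate.
        dt *= min(2.0, max(0.2, 0.9 * math.sqrt(tol / max(err, 1e-16))))
    return u, steps, reg


u_easy, steps_easy, reg_easy = solve_adaptive(lambda u, t: -u, 1.0, 0.0, 1.0)
u_hard, steps_hard, reg_hard = solve_adaptive(lambda u, t: -50.0 * u, 1.0, 0.0, 1.0)

print(u_easy, steps_easy, reg_easy)
print(u_hard, steps_hard, reg_hard)
```

The quantity `reg` only exists inside the while loop, and the number of terms in it depends on the values being solved for. The faster-decaying problem takes more accepted steps, which is the sense in which the estimate measures how hard the ODE is to solve.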

Notice where we have ended up: the resulting algorithm is necessarily not quasi-static. This error estimate is computed by the actual steps of the adaptive ODE solver: to compute this error estimate, you have to do the same computations, the same while loop, as the ODE solver. In this algorithm, you cannot avoid directly differentiating the ODE solver because pieces of the ODE solver’s internal calculations are now part of the regularization. This is something that is fundamentally not optimized by methods that require quasi-static compute graphs (Jax, Tensorflow, etc.), and it is something that makes hyperparameter optimization cheaper than not doing hyperparameter optimization since the regularizer is computed for free. I just find this result so cool!

Conclusion: Finding the Limitations of Our Tools

So yes, the paper is a machine learning paper on how to do hyperparameter optimization for free using a trick on neural ODEs, but I think the general software context this sits in highlights the true finding of the paper. This is the first algorithm that I know of where there is both a clear incentive for it to be used in modern machine learning and a fundamental reason why common machine learning frameworks like Jax and Tensorflow will not be able to treat it optimally. Even PyTorch’s TorchScript will fundamentally, due to the assumptions of its compilation process, not work on this algorithm. Those assumptions were smartly chosen because most algorithms can satisfy them, but this one cannot. Does this mean machine learning is algorithmically stuck in a rut? Possibly, because I thoroughly believe that someone working within a toolset that does not optimize this algorithm would have never found it, which makes it very thought-provoking to me.

What other algorithms are out there which are simply better than our current approaches but are worse only because of the current machine learning frameworks? I cannot wait until Diffractor.jl’s release to start probing this question deeper.

The post Useful Algorithms That Are Not Optimized By Jax, PyTorch, or Tensorflow appeared first on Stochastic Lifestyle.

Generalizing Automatic Differentiation to Automatic Sparsity, Uncertainty, Stability, and Parallelism

By: Christopher Rackauckas

Re-posted from: http://www.stochasticlifestyle.com/generalizing-automatic-differentiation-to-automatic-sparsity-uncertainty-stability-and-parallelism/

Automatic differentiation is a “compiler trick” whereby a code that calculates f(x) is transformed into a code that calculates f'(x). This trick and its two forms, forward and reverse mode automatic differentiation, have become the pervasive backbone behind all of the machine learning libraries. If you ask what PyTorch or Flux.jl is doing that’s special, the answer is really that it’s doing automatic differentiation over some functions.

What I want to dig into in this blog post is a simple question: what is the trick behind automatic differentiation, why is it always differentiation, and are there other mathematical problems we can be focusing this trick towards? While very technical discussions on this can be found in our recent paper titled “ModelingToolkit: A Composable Graph Transformation System For Equation-Based Modeling” and descriptions of methods like intrusive uncertainty quantification, I want to give a high-level overview that really describes some of the intuition behind the technical thoughts. Let’s dive in!

What is the trick behind automatic differentiation? Abstract interpretation

To understand automatic differentiation in practice, you need to understand that it’s at its core a code transformation process. While mathematically it comes down to being about Jacobian-vector products and Jacobian-transpose-vector products for forward and reverse mode respectively, I think sometimes that mathematical treatment glosses over the practical point that it’s really about code.

Take for example $f(x) = \sin(x)$. If we want to take the derivative of this, then we could do $\frac{f(x+h) - f(x)}{h}$, but this misses the information that we actually know analytically how to define the derivative! Using the principle that algorithm efficiency comes from problem information, we can improve this process by directly embedding that analytical solution into our process. So we come to the first principle of automatic differentiation:

If you know the analytical solution to the derivative, then replace the function with its derivative

So if you see $f(x) = \sin(x)$ and someone calls `derivative(f,x)`, you can do a quick little lookup to a table of rules, known as primitives, and if it’s in your table then boom you’re done. Swap it in, call it a day.

This already shows you that, with automatic differentiation, we cannot think of $f(x)$ as just a function, just a thing that takes in values, but we have to know something about what it means semantically. We have to look at it and identify “this is sin” in order to know “replace it with cos”. This is the fundamental limitation of automatic differentiation: it has to know something about your code, more information than it takes to call or run your code. This is why many automatic differentiation libraries are tied to specific implementations of underlying numerical primitives. PyTorch understands `torch.sin` as $\sin$, but it does not understand `tf.sin` as $\sin$, which is why if you place a TensorFlow function into a PyTorch training loop you will get an error thrown about the derivative calculation. This semantic mapping is the reason for libraries like ChainRules.jl, which define semantic mappings for the Julia Base library and allow extensions: by directly knowing this mapping on all of standard Julia Base, you can cover the language and achieve “differentiable programming”, i.e. all programs automatically can get derivatives.
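The lookup-table idea can be sketched in a few lines of Python. The table `PRIMITIVES` and the function `derivative` are hypothetical names for illustration, not the API of any library mentioned here. The sketch shows both halves of the argument: a known primitive gets its analytical derivative, while a wrapper that hides the identity of `sin` falls back to a finite difference.

```python
import math

# Hypothetical table of primitives: function -> known analytical derivative.
PRIMITIVES = {
    math.sin: math.cos,
    math.cos: lambda x: -math.sin(x),
    math.exp: math.exp,
}

def derivative(f, x, h=1e-6):
    """Use the rule table when f is a known primitive, else a finite difference."""
    rule = PRIMITIVES.get(f)
    if rule is not None:
        return rule(x)
    return (f(x + h) - f(x)) / h

exact = derivative(math.sin, 1.0)                 # table hit: evaluates cos(1.0)
approx = derivative(lambda x: math.sin(x), 1.0)   # the lambda hides that this is sin
```

The second call computes the same mathematical function, but because the table only recognizes the object `math.sin`, the semantic information is lost and only an approximation is available.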

But we’re not done. Let’s say we have $f(g(x))$. The answer is not to add this new function to the table by deriving it by hand: instead we have to come up with a way to make a function generate a derivative code whenever $f$ and $g$ are in our lookup table. The answer comes from the chain rule. I’m going to describe the forward application of the chain rule as it’s a bit simpler to derive, but a full derivation of how this is done in the reverse form is described in these lecture notes. The chain rule tells us that $f(g(x))' = f'(g(x)) g'(x)$. Thus in order to calculate $f(g(x))'$, we need to know two things: $g(x)$ and $g'(x)$. If we calculate both of these quantities at every stage of our code, it doesn’t matter how deep the composition goes, we will have all of the information that is required to reconstruct the result of the chain rule.

What this means is that automatic differentiation on this function can be thought of as the following translation process:

  1. Transform $g$ to $(g, g')$ and evaluate at $x$
  2. Transform $f$ to $(f, f')$ and evaluate at $g(x)$
  3. Transform $f(g(x))$ to $(f(g(x)), f'(g(x)) g'(x))$. Now the second portion is the solution to the derivative

This translation process, “transform every primitive function into a tuple of (function,derivative), and transform every other function into a chain rule application using the two pieces”, is abstract interpretation. This is the process where an interpreter of a code or language runs under different semantics. An interpreter written to do this process acts on the same code but interprets it differently: it changes each operation $f(x)$ to a tuple $(f(x), f'(x))$ of the solution and its derivative, instead of just the solution $f(x)$.

Thus the abstract interpretation version of the problem of calculating derivatives is to reimagine the problem as “at this step of the code, how should I be transforming it so that I have the information to calculate derivatives”? There are many ways to do this abstract interpretation process: operator overloading, prior static analysis to generate a new source code, etc. But there’s one question we should bring up.
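As one illustration of the operator-overloading route, here is a minimal forward-mode sketch in Python. The class `Dual` and the helper names are assumptions for this example. Every value is carried as a (value, derivative) pair, primitives apply their known rule, and everything else falls out of the chain rule.

```python
import math

class Dual:
    """A (value, derivative) pair propagated through the program."""
    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.der + other.der)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule on the derivative component.
        return Dual(self.val * other.val,
                    self.der * other.val + self.val * other.der)
    __rmul__ = __mul__

def sin(x):
    # Primitive rule: (sin(g), cos(g) * g')
    return Dual(math.sin(x.val), math.cos(x.val) * x.der)

def cos(x):
    # Primitive rule: (cos(g), -sin(g) * g')
    return Dual(math.cos(x.val), -math.sin(x.val) * x.der)

def derivative(f, x):
    # Seed the input with derivative 1, then read off the second component.
    return f(Dual(x, 1.0)).der

def h(x):
    return sin(cos(x)) + 3 * x * x

d = derivative(h, 0.5)
```

The function `h` is ordinary code. Only the meaning of its operations changed, which is the sense in which the same program is run under different semantics.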

Why do we see it always on differentiation? Why is there no automatic integration?

One way to start digging into this question is to answer a related question people pose to me often: if we have automatic differentiation, why do we not have automatic integration? While at face value it seems like the two should be analogues, digging deeper exposes what’s special about differentiation. If we wanted to do the integral of $\sin(x)$, then yes we can replace this with $-\cos(x)$. The heart of the question is to ask about the chain rule: what’s the antiderivative of $f(g(x))$? It turns out that there is no general rule for the “anti-chain rule”. A commonly known result is that the standard Gaussian probability distribution, $\frac{1}{\sqrt{2\pi}} e^{-x^2/2}$, does not have an analytical solution to its antiderivative, and that’s just the case of $f(x) = e^x$ and $g(x) = -x^2/2$. While that is true, I don’t think that captures all that is different about integrals.

When I said “we can replace this with $-\cos(x)$” I was actually wrong: the antiderivative of $\sin(x)$ is not $-\cos(x)$, it’s $-\cos(x) + C$. There is no unique solution without imposing some external context or some global information like “and F(0)=0”. Differentiation is special because it’s purely local: only knowing the value of $x$ I can know the derivative of $\sin(x)$. Integration is a well-known example of a non-local operation in mathematics: in order to know the anti-derivative at a value of $x$, you might need to know information about some value $x_0$, and sometimes it’s not necessarily obvious what that value should even be. This nonlocality manifests in other ways as well: while $e^{-x^2/2}$ is not integrable, $x e^{-x^2/2}$ is easy to solve via a u-substitution, making $u = -x^2/2$ and cancelling out the $x$ in front of the $e^{-x^2/2}$. So there is no chain rule not because some things don’t have an antiderivative, but because you have nonlocality, so $f(g(x))$ can be non-integrable while $g'(x) f(g(x))$ is. There is no chain rule because you can’t look at small pieces and transform them, instead you have to look at the problem holistically.
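To see the u-substitution point worked out, take $x e^{-x^2/2}$ as an illustrative integrand (the specific function is an assumption chosen to match the Gaussian example). With $u = -x^2/2$ we have $du = -x \, dx$, so the extra factor of $x$ is absorbed by the change of variables:

$$\int x e^{-x^2/2} \, dx = -\int e^{u} \, du = -e^{u} + C = -e^{-x^2/2} + C$$

The factor $x$ is only useful because it matches $g'(x)$ up to sign, which is a fact about the whole expression rather than about either piece on its own.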

But this gives us a framework in order to judge whether a mathematical problem is amenable to being solved in the framework of abstract interpretation: it must be local so that we can define a step-by-step transformation algorithm, or we need to include/impose some form of context if we have alternative information.

Let’s look at a few related problems that can be solved with this trick.

Automatic Sparsity Detection in Jacobians

Recall that the Jacobian is the matrix of partial derivatives, i.e. for $y = f(x)$ where $x$ and $y$ are vectors, it’s the matrix of terms $\frac{\partial y_i}{\partial x_j}$. This matrix shows up in tons of mathematical algorithms, and in many cases it’s sparse, so it’s a common problem to try and compute the sparsity pattern of a Jacobian. But what does this sparsity pattern mean? If you write out the analytical solution to $f$, a zero in the Jacobian means that $y_i$ is not a function of $x_j$. In other words, $x_j$ has no influence on $y_i$. For an arbitrary program $f$, can we use abstract interpretation to calculate whether $x_j$ influences $y_i$?

It turns out that if we make this question a little simpler then it has a simple solution. Let’s instead ask, can we use abstract interpretation to calculate whether $x_j$ can influence $y_i$? The reason for this change is that the previous question was non-local: $x_j - x_j$ is programmatically dependent on $x_j$, but mathematically $x_j - x_j = 0$ so you could cancel it out if you have good global knowledge of the program. So “does it influence” this output is hard, but “can it influence” the output is easy. “Can influence” is just the question of “does $x_j$ show up in the calculation at all?”

So we can come up with an abstract interpretation formulation to solve this problem. Instead of computing values, we can compute “influencer sets”. The output of $x_i$ is influenced by $\{x_i\}$. The output of $x_i x_j$ is influenced by $\{x_i, x_j\}$. For a unary function like $\sin(x_i)$, the influencer set of $\sin(x_i)$ is the same as the influencer set of $x_i$. So our abstract interpretation is to replace variables by influencer sets, and whenever they collide in a binary function like multiplication, we make the new influencer set be the union of the two. Otherwise we keep propagating it forward. The result of running the program this way is output sets that say “these are all of the variables which can be influencing this output variable”. If $x_j$ never shows up at any stage of the computation of $y_i$, then there is no way it could ever influence it, and therefore $\frac{\partial y_i}{\partial x_j} = 0$. So the sparsity pattern is bounded by the influencer set.

This is the process behind the automated sparsity tooling of the Julia programming language, which is now embedded as part of Symbolics.jl. There is a bit more you have to do if you see branching: i.e. you have to take all branches and take the union of the influencer sets (so it’s clear this cannot be generally solved with just operator overloading because you need to take non-local control of control flow). Details on the full process are described in Sparsity Programming: Automated Sparsity-Aware Optimizations in Differentiable Programming, along with extensions to things like Hessians, which is all about tracking sets of linear dependence.
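Here is a minimal Python sketch of influencer-set propagation for straight-line code. The class `Tracked` and the helpers are hypothetical names, this is not how Symbolics.jl is implemented, and branching is deliberately not handled since, as noted above, that needs more than operator overloading.

```python
class Tracked:
    """Replaces a value by the set of input indices that can influence it."""
    def __init__(self, deps=()):
        self.deps = frozenset(deps)

    def _combine(self, other):
        # Binary operation: union of the two influencer sets.
        # A plain number contributes no inputs.
        other_deps = other.deps if isinstance(other, Tracked) else frozenset()
        return Tracked(self.deps | other_deps)

    __add__ = __radd__ = __sub__ = __rsub__ = _combine
    __mul__ = __rmul__ = __truediv__ = __rtruediv__ = _combine

def unary(x):
    """Any unary primitive (sin, exp, ...) keeps the influencer set unchanged."""
    return Tracked(x.deps)

def sparsity_pattern(f, n):
    """Run f on influencer sets and return a boolean Jacobian pattern."""
    inputs = [Tracked({j}) for j in range(n)]
    outputs = f(inputs)
    return [[j in y.deps for j in range(n)] for y in outputs]

def f(x):
    return [x[0] * x[1], unary(x[2]) + 2.0, x[0] - x[0]]

pattern = sparsity_pattern(f, 3)
# pattern == [[True, True, False],
#             [False, False, True],
#             [True, False, False]]
```

The last row marks the first input as a possible influencer of `x[0] - x[0]` even though the true derivative is zero. That is the "can influence" over-approximation: the computed pattern bounds the true sparsity pattern rather than matching it exactly.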

Automatic Uncertainty Quantification

Let’s say we didn’t actually know the value to put into our program, and instead it was some $x \pm \sigma_x$. What could we do then? In a standard physics class you probably learned a few rules for uncertainty propagation: $(x \pm \sigma_x) + (y \pm \sigma_y) = (x + y) \pm \sqrt{\sigma_x^2 + \sigma_y^2}$ and so on. It’s good to understand where this all comes from. If you said that $X$ was a random variable from a normal distribution with mean $x$ and standard deviation $\sigma_x$, and $Y$ was an independent random variable from a normal distribution with mean $y$ and standard deviation $\sigma_y$, then $X + Y$ would have mean $x + y$ and standard deviation $\sqrt{\sigma_x^2 + \sigma_y^2}$. This means there are some local rules for propagating normally distributed random variables! What do you do for $f(X)$ for a normally distributed input? You could approximate $f$ with its linear approximation: $f(X) \approx f(x) + f'(x)(X - x)$ (this is another way to state that the derivative is the tangent vector at $x$). At a given value of $x$ then we just have a linear equation for some scalar $a = f'(x)$, in which case we use the rule $a(x \pm \sigma_x) = ax \pm |a| \sigma_x$. This gives an abstract interpretation implementation for approximately computing with normal distributions! Now all we have to do is replace function calls with automatic differentiation around the mean and then propagate forward our error bounds.

This is a very crude description which you can expand to linear error propagation theory where you can more accurately treat the nonlinear propagation of variance. However, this is still missing out on whether two variables are dependent: $x - x = 0$, there’s no uncertainty there, so you need to treat that in a special way! If you think about it, “dependence propagation” is very similar to “propagating influencer sets”, which you can use to even more accurately propagate the variance terms. This gives rise to the package Measurements.jl which transforms code to make it additionally do propagation of normally distributed uncertainties.
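The crude version of these rules fits in a short Python sketch. The class `Uncertain` and the helper `apply` are hypothetical names, and this is not the Measurements.jl implementation. It shows the sum rule, the linearized function rule, and the failure on dependent variables that the naive rules run into.

```python
import math

class Uncertain:
    """A mean with a standard deviation, propagated by first-order rules."""
    def __init__(self, val, err):
        self.val, self.err = val, abs(err)

    def __add__(self, other):
        # Independent normals: standard deviations add in quadrature.
        return Uncertain(self.val + other.val, math.hypot(self.err, other.err))

    def __sub__(self, other):
        return Uncertain(self.val - other.val, math.hypot(self.err, other.err))

def apply(f, df, x):
    """Linearize f around the mean: f(X) is approximately f(x) + f'(x) (X - x)."""
    return Uncertain(f(x.val), abs(df(x.val)) * x.err)

x = Uncertain(1.0, 0.3)
y = Uncertain(2.0, 0.4)
s = x + y                              # 3.0 with standard deviation about 0.5
z = apply(math.sin, math.cos, x)       # derivative supplied by hand here
d = x - x                              # naive rules treat the operands as independent
```

Here `d.val` is zero but `d.err` is not, because the rule has no way to know that both operands are the same variable. Tracking which inputs each value depends on is what fixes this.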

I note in passing that Interval Arithmetic is very similarly formulated as an alternative interpretation of a program. David Sanders has a great tutorial on what this is all about and how to make use of it, so check that out for more information.
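For a rough picture of what that alternative interpretation looks like, here is a small Python sketch. The class `Interval` is a hypothetical name, and a rigorous implementation would also round the endpoints outward in floating point, which this sketch does not do.

```python
class Interval:
    """A closed interval [lo, hi] that replaces a single number."""
    def __init__(self, lo, hi):
        self.lo, self.hi = lo, hi

    def __add__(self, other):
        return Interval(self.lo + other.lo, self.hi + other.hi)

    def __sub__(self, other):
        return Interval(self.lo - other.hi, self.hi - other.lo)

    def __mul__(self, other):
        products = [self.lo * other.lo, self.lo * other.hi,
                    self.hi * other.lo, self.hi * other.hi]
        return Interval(min(products), max(products))

def p(x):
    return x * x - x

r = p(Interval(0.0, 1.0))   # encloses every value of x*x - x for x in [0, 1]
```

The result is the interval from -1 to 1, which contains the true range from -0.25 to 0 but is wider, because each occurrence of `x` is treated as if it could vary independently. This is the same dependence issue as in the uncertainty case.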

Using Context Information: Automatic Index Reduction in DAEs and Parallelism

Now let’s look at solving something a little bit deeper: simulating a pendulum. I know you’ll say “but I solved pendulum equations by hand in physics” but sorry to break it to you: you didn’t. In an early physics class you will say “all pendulums are one dimensional”, and “the angle is small, so sin(x) is approximately x” and arrive at a beautiful linear ODE that you analytically solve. But the world isn’t that beautiful. So let’s look at the full pendulum. Instead you have the location of the swinger $(x, y)$ and its velocity $(v_x, v_y)$. But you also have non-constant tension $T$, and if we have a rigid rod we know that the distance of the swinger to the origin is constant. So the evolution of the system is in full given by:

$$\begin{aligned} x' &= v_x \\ v_x' &= T x \\ y' &= v_y \\ v_y' &= T y - g \\ 0 &= x^2 + y^2 - L^2 \end{aligned}$$

There are differential equation solvers that can handle constraint equations; these are methods for solving differential-algebraic equations (DAEs). But if you take this code and give it to pretty much any differential equation solver it will fail. It’s not because the differential equation solver is bad, but because of the properties of this equation. See, the derivative of the last equation with respect to $T$ is zero, so you end up getting a singularity in the Newton solve that makes the stepping unstable. This singularity of the algebraic equation with respect to the algebraic variables (i.e. the ones not defined by derivative terms) is known as “higher index”. DAE solvers generally only work on index-1 DAEs, and this is an index-3 DAE. What does index-3 mean? It means if you take the last equation, differentiate it twice, and then do a substitution, you get:

$$\begin{aligned} x' &= v_x \\ v_x' &= T x \\ y' &= v_y \\ v_y' &= T y - g \\ 0 &= 2\left(v_x^2 + v_y^2 + y (T y - g) + T x^2\right) \end{aligned}$$

This is mathematically the same system of equations, but this formulation of the equations doesn’t have the index issue, and so if you give this to a numerical solver it will work.
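The two differentiations can be written out explicitly. The notation here is an assumption: position $(x, y)$, velocity $(v_x, v_y)$, tension variable $T$, gravitational acceleration $g$, rod length $L$, and dynamics $x' = v_x$, $v_x' = T x$, $y' = v_y$, $v_y' = T y - g$. Starting from the length constraint:

$$\begin{aligned} 0 &= x^2 + y^2 - L^2 \\ 0 &= 2\left(x v_x + y v_y\right) && \text{(differentiate once, use } x' = v_x,\ y' = v_y\text{)} \\ 0 &= 2\left(v_x^2 + x v_x' + v_y^2 + y v_y'\right) && \text{(differentiate again)} \\ 0 &= 2\left(v_x^2 + v_y^2 + T x^2 + y (T y - g)\right) && \text{(substitute } v_x' = T x,\ v_y' = T y - g\text{)} \end{aligned}$$

In the final line $T$ appears with coefficient $2(x^2 + y^2) = 2L^2$, which is nonzero, so the algebraic equation can now be solved for the algebraic variable and the Newton iteration is no longer singular.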

It turns out that you can reimagine this algorithm to be something that is also solved by a form of abstract interpretation. This is one of the nice unique features of ModelingToolkit.jl, spelled out in its recent paper. While algorithms have been written before for symbolic equation-based modeling systems, it turns out you can use an abstract interpretation process to extract the formulation of the equations, solve a graph algorithm to determine how many times you should differentiate which equations, do the differentiation using rules and structures from automatic differentiation libraries, and then regenerate the code for “the better version”. As shown in a ModelingToolkit.jl tutorial on this feature, if you do this on the pendulum equations, you can change a code that is too unstable to solve into an equation that is easy enough to solve by the simplest solvers.

And then you can go even further. As I described in the JuliaCon 2020 talk on automated code optimization, now that one is regenerating the code, you can step in and construct a graph of dependencies and automatically compute independent portions simultaneously in the generated code. Thus with no cost to the user, an abstract interpretation into symbolic graph building can be used to reconstruct and automatically parallelize code. The ModelingToolkit.jl paper takes this even further by showing how a code which is not amenable to parallelism can, after context-specific equation changes like the DAE index reduction, be transformed into an alternative variant that is suddenly embarrassingly parallel.
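The dependency-graph step can be sketched with Python's standard library. The function `parallel_stages` and the example graph are hypothetical, and this is not how ModelingToolkit.jl generates code. The sketch only shows how a graph of "what is computed from what" splits into stages whose members are mutually independent and could run at the same time.

```python
from graphlib import TopologicalSorter

def parallel_stages(dependencies):
    """Group computations into stages; items within a stage do not depend on each other."""
    sorter = TopologicalSorter(dependencies)
    sorter.prepare()
    stages = []
    while sorter.is_active():
        ready = sorted(sorter.get_ready())   # everything whose inputs are done
        stages.append(ready)
        sorter.done(*ready)
    return stages

# Hypothetical generated code: each key is computed from the names in its set.
deps = {
    "a": {"x"},
    "b": {"y"},
    "c": {"a", "b"},
    "d": {"x", "y"},
}
stages = parallel_stages(deps)
# stages == [["x", "y"], ["a", "b", "d"], ["c"]]
```

Once the code is regenerated from such a graph, each stage can be dispatched to separate threads or tasks without the user changing the original model.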

All of these features require a bit of extra context as to “what equations you’re solving” and information like “do you care about maintaining the same exact floating point result”, but by adding in the right amount of context, we can extend mathematical abstract interpretation to solve these alternative problems in new domains.

By the way, if this excites you and you want to see more updates like this, please star ModelingToolkit.jl.

Final Trick: Constructing a PDE Solver Out Of An ODE Solver with Abstract Interpretation

Let me end by going a little bit more mathematical. You can transform code about scalars into code about functions by using the vector-space interpretation of a function. In mathematical terms, functions are vectors in Banach spaces. Or, in the space of functions, functions are points. If you have a function $f$ and a function $g$, then $f + g$ is a function too, and so is $f * g$. You can do computation on these “points” by working out their expansions. For example, you can write $f$ and $g$ in their Fourier series expansions: $f(x) = \sum_k a_k \cos(kx) + b_k \sin(kx)$. Approximating with finitely many expansion terms, you can represent $f$ via `[a[1:n];b[1:n]]`, and same for $g$. The representation of $f + g$ can be worked out from the finite truncation (just add the coefficients), and so can $f * g$. So you can transform your computation about “functions” to “arrays of coefficients representing functions”, and derive the results for what $+$, $*$, etc. do on these values. This is an abstract interpretation of a program that transforms it into an equivalent program with functions and measures as inputs.
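A small Python sketch makes the "arrays of coefficients" idea tangible. The class `FourierFun` is a hypothetical name, the constant term is omitted to match the `[a[1:n];b[1:n]]` layout, and only addition and scaling are shown since they act coefficient-wise. A product of two functions needs a convolution of coefficients and is left out.

```python
import math

class FourierFun:
    """f(x) = sum over k = 1..n of a[k-1] cos(kx) + b[k-1] sin(kx)."""
    def __init__(self, a, b):
        self.a, self.b = list(a), list(b)

    def __add__(self, other):
        # Adding functions is adding coefficients.
        return FourierFun([p + q for p, q in zip(self.a, other.a)],
                          [p + q for p, q in zip(self.b, other.b)])

    def __rmul__(self, c):
        # Scaling a function by a number scales every coefficient.
        return FourierFun([c * p for p in self.a], [c * p for p in self.b])

    def __call__(self, x):
        return sum(ak * math.cos(k * x) + bk * math.sin(k * x)
                   for k, (ak, bk) in enumerate(zip(self.a, self.b), start=1))

f = FourierFun([1.0, 0.0], [0.0, 2.0])   # cos(x) + 2 sin(2x)
g = FourierFun([0.5, 1.0], [1.0, 0.0])   # 0.5 cos(x) + cos(2x) + sin(x)
h = f + 3.0 * g                          # computed purely on coefficients
```

Evaluating `h(x)` gives the same number as `f(x) + 3.0 * g(x)` up to rounding, even though the arithmetic that produced `h` never evaluated either function.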

Now it turns out you can formally use this to do cool things. A partial differential equation (PDE) is an ODE where, instead of your values being scalars at each time $t$, your values are now functions at each time $t$. So what if you represent those “functions” as “scalars” via their representation in the Sobolev space, and then put those “scalars” into the ODE solver? You automatically transform your ODE solver code into a PDE solver code. Formally, this is using a branch of PDE theory known as semigroup theory and making it a computable object.

It turns out this is something you can do. ApproxFun.jl defines types `Fun` which represent the functions as scalars in a Banach space, and defines a bunch of operations that are allowed for such functions. I showed in a previous talk that you can slap these into DifferentialEquations.jl to have it reinterpret into this function-based differential equation solver, and then start to solve PDEs via this representation.
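The same reinterpretation can be imitated in a few lines of Python. This is a sketch under stated assumptions and not the ApproxFun.jl or DifferentialEquations.jl API: `euler` and `SineSeries` are hypothetical names, the solver is plain forward Euler, and the function space is a truncated sine series on $[0, \pi]$. The solver only uses `+` and scalar `*`, so it runs unchanged on numbers and on functions.

```python
import math

def euler(f, u0, t0, t1, n):
    """Forward Euler written only in terms of + and scalar *."""
    dt = (t1 - t0) / n
    u, t = u0, t0
    for _ in range(n):
        u = u + dt * f(u, t)
        t += dt
    return u

class SineSeries:
    """u(x) = sum over k of c[k-1] sin(kx) on [0, pi], zero at both ends."""
    def __init__(self, c):
        self.c = list(c)

    def __add__(self, other):
        return SineSeries([p + q for p, q in zip(self.c, other.c)])

    def __rmul__(self, s):
        return SineSeries([s * p for p in self.c])

    def laplacian(self):
        # The second x-derivative of sin(kx) is -k^2 sin(kx).
        return SineSeries([-(k * k) * p for k, p in enumerate(self.c, start=1)])

# Scalar ODE u' = -u
scalar = euler(lambda u, t: -1.0 * u, 1.0, 0.0, 1.0, 1000)
# Heat equation u_t = u_xx, solved by the very same euler function
heat = euler(lambda u, t: u.laplacian(), SineSeries([1.0, 0.5]), 0.0, 1.0, 1000)
```

Each coefficient of `heat` decays roughly like $e^{-k^2 t}$, as the heat equation predicts, and no line of `euler` had to know that its "scalar" was a function.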

Conclusion: Abstract Interpretation is Powerful for Mathematics

Automatic differentiation gets all of the attention, but it’s really this idea of abstract interpretation that we should be looking at. “How do I change the semantics of this program to solve another problem?” is a very powerful approach. In computer science it’s often used for debugging: recompile this program into the debugging version. And in machine learning it’s often used to recompile programs into derivative calculators. But uncertainty quantification, fast sparsity detection, automatic stabilization and parallelization of differential-algebraic equations, and automatic generation of PDE solvers all arise from the same little trick. That can’t be all there is out there: the formalism and theory of abstract interpretation seems like it could be a much richer field.

Bibliography

  1. ModelingToolkit: A Composable Graph Transformation System For Equation-Based Modeling
  2. Sparsity Programming: Automated Sparsity-Aware Optimizations in Differentiable Programming
  3. Uncertainty propagation with functionally correlated quantities

The post Generalizing Automatic Differentiation to Automatic Sparsity, Uncertainty, Stability, and Parallelism appeared first on Stochastic Lifestyle.