How To Train Interpretable Neural Networks That Accurately Extrapolate From Small Data

By: Christopher Rackauckas

Re-posted from: http://www.stochasticlifestyle.com/how-to-train-interpretable-neural-networks-that-accurately-extrapolate-from-small-data/

The fundamental problems of classical machine learning are:

  1. Machine learning models require big data to train
  2. Machine learning models cannot extrapolate well outside of their training data
  3. Machine learning models are not interpretable

However, in our recent paper, we have shown that this does not have to be the case. In Universal Differential Equations for Scientific Machine Learning, we start by showing the following figure:

Indeed, it shows that by only seeing the tiny first part of the time series, we can automatically learn the equations in such a manner that it predicts the time series will be cyclic in the future, in a way that even gets the periodicity correct. Not only that, our result is not a neural network; rather, the program itself spits out the LaTeX for the differential equation:

x^\prime = \alpha x - \beta x y

y^\prime = -\delta y + \gamma x y

with the correct coefficients, which is exactly how we generated the data.
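For concreteness, here is a minimal sketch (not the paper's actual script) of how one would generate this kind of training data in Julia with OrdinaryDiffEq.jl; the coefficient values and time window are illustrative placeholders, not the paper's:

```julia
using OrdinaryDiffEq

function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α*x - β*x*y    # prey: exponential growth minus predation
    du[2] = -δ*y + γ*x*y   # predators: exponential decay plus growth from eating prey
end

u0 = [1.0, 1.0]             # illustrative initial populations
p = (1.5, 1.0, 3.0, 1.0)    # illustrative α, β, δ, γ
tspan = (0.0, 3.0)          # a deliberately short window: the “small data” regime
prob = ODEProblem(lotka_volterra!, u0, tspan, p)
sol = solve(prob, Tsit5(), saveat = 0.1)   # the sampled time series used for training
```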

Rather than just explaining the method, what I really want to convey in this blog post is why it intuitively makes sense that our methods work. Intuitively, we utilize all of the known scientific structure to embed as much prior knowledge as possible. This is the opposite of most modern machine learning, which tries to use blackbox architectures to fit as wide a range of behaviors as possible. Instead, we look at our problem and ask: what do I know has to be true about the system, and how can I constrain the neural network so that the parameter search only looks at cases where it holds? In the context of science, we do so by directly embedding neural networks into existing scientific simulations, essentially saying that the model is at least approximately accurate, so our neural network should only learn what the scientific model didn’t cover or simplified away. Our approach has many more applications than what we show in the paper, and once you know the underlying idea, it is quite straightforward to apply to your own work.

Quick Overview of the Universal Differential Equation Approach

The starting point for universal differential equations is the now classic work on neural ordinary differential equations. Neural ODEs are defined by the equation:

u^\prime = \text{NN}_\theta (u),

i.e. it’s an arbitrary function described as the solution to an ODE defined by a neural network. The reason the authors went down this route is that it’s a continuous form of a recurrent neural network, which makes it natural for handling irregularly-spaced time series data.
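As a rough sketch of what this means in code, here is a hand-rolled neural ODE right-hand side in Julia. The two-layer network and its random weights are placeholders for the trained parameters θ; in practice one would train them by differentiating through the solver, e.g. with DiffEqFlux.jl:

```julia
using OrdinaryDiffEq

# A hand-rolled two-layer network standing in for NN_θ. The weights are
# random placeholders; they are what training would actually adjust.
W1, b1 = randn(16, 2), randn(16)
W2, b2 = randn(2, 16), randn(2)
nn(u) = W2 * tanh.(W1 * u .+ b1) .+ b2

rhs(u, p, t) = nn(u)   # u' = NN_θ(u)

prob = ODEProblem(rhs, [2.0, 0.0], (0.0, 1.0))
sol = solve(prob, Tsit5())
```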

However, this formulation can have another interesting purpose. ODEs, and differential equations in general, are well-studied because they are the language of science. Newtonian physics is described in terms of differential equations, as are Einstein’s equations and quantum mechanics. And not only physics: biological models of chemical reactions in cells, population sizes in ecology, the motion of fluids, and much more. Differential equations really are the language of science.

Thus it’s not a surprise that in many cases we already have and know some differential equations. They may be an approximate model, but we know this approximation is “true” in some sense. This is the jumping-off point for the universal differential equation. Since science is encoded in differential equations, instead of trying to make the entire differential equation a neural-network-like object, it would be scientifically informative to learn the differential equation itself. And in any scientific context we already know parts of the differential equation, so we might as well hard-code that information as a form of prior structural knowledge. This gives us the form of the universal differential equation:

u^\prime = f(u,U_\theta,t)

where U_\theta is an arbitrary universal approximator, i.e. a finite-parameter object that can represent “any possible function”. Neural networks happen to be universal approximators, but other forms, like Chebyshev polynomials, have this property as well; neural networks simply do well in high dimensions and on irregular grids (properties we utilize in some of our other examples).

What happens when we describe a differential equation in this form is that the trained neural network becomes a tangible numerical approximation of the missing function. By doing this, we can train a program that has the same exact input/output behavior as the missing term of our model. And this is precisely what we do. We assume we only know part of the differential equation:

x^\prime = \alpha x - U^1_\theta (x,y)

y^\prime = -\delta y + U^2_\theta(x,y)

and train the neural networks so that the embedded networks define a universal ODE that fits our data. When trained, each neural network is a numerical approximation to its missing function. But since it’s just a simple function, it’s fairly straightforward to plot it and say “hey! We were missing a quadratic term”, and there you go: interpreted back to the original generating equations. In the paper we describe how to make this more rigorous using the SINDy technique, a sparse regression against a basis of possible terms, but it’s the same story: in the end we learn exactly the differential equations that generated the data, and hence get the extrapolation accuracy, even beyond the original time series, seen in the figure at the top of this post.
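A minimal sketch of what this looks like in code, where the hypothetical networks U1 and U2 below stand in for U^1_\theta and U^2_\theta with untrained placeholder weights:

```julia
using OrdinaryDiffEq

# Tiny networks standing in for U¹_θ and U²_θ. The weights are untrained
# placeholders; training would adjust them so that the solved trajectory
# matches the data (e.g. via DiffEqFlux.jl, differentiating the solver).
U1_W1, U1_b1, U1_W2 = randn(8, 2), randn(8), randn(1, 8)
U2_W1, U2_b1, U2_W2 = randn(8, 2), randn(8), randn(1, 8)
U1(x, y) = (U1_W2 * tanh.(U1_W1 * [x, y] .+ U1_b1))[1]
U2(x, y) = (U2_W2 * tanh.(U2_W1 * [x, y] .+ U2_b1))[1]

function ude!(du, u, p, t)
    x, y = u
    α, δ = p                  # the structural parts we claim to know
    du[1] = α*x - U1(x, y)    # known exponential growth, unknown predation term
    du[2] = -δ*y + U2(x, y)   # known exponential decay, unknown growth term
end

prob = ODEProblem(ude!, [1.0, 1.0], (0.0, 3.0), (1.5, 3.0))
sol = solve(prob, Tsit5(), saveat = 0.1)
```

Once trained, U1 can be evaluated over a grid of (x, y) values and compared against candidate basis terms like x*y, which is exactly where the sparse regression enters.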

Understanding Universal Differential Equations As A Method to Simplify Learning

Trying to approximate data might be a much harder problem than trying to understand the processes that create the data. Indeed, the Lotka-Volterra equations are a simple set of equations defined by 4 interpretable terms. The first simply states that the number of rabbits would grow exponentially if there weren’t predators eating them. The second states that the number of rabbits goes down when they are eaten by predators (and more predators means more eating). The third is that more prey means more food, and thus growth of the wolf population. Finally, the wolves die off with an exponential decay due to old age.

[Figure: the Lotka-Volterra equations with each term annotated with emojis]

That’s it: a simple quadratic equation that describes 4 mechanisms of interaction. Each mechanism is interpretable and can be independently verified. Meanwhile, that cyclic solution that is the data set? The time series itself is such a complicated function that you can prove there is no way to even express its analytical solution! This phenomenon is known as emergence: simple mechanisms can give rise to complex behavior. From this it should be clear that a method which is trying to predict a time series has a much more difficult problem to solve than one that is trying to learn mechanisms!

One way to really solidify this idea is our next example, in which we showcase how reconstructing a partial differential equation can be done straightforwardly through universal approximators embedded within partial differential equations, what we call a universal PDE. If you want the full details behind what I will handwave here, take a look at the MIT 18.337 Scientific Machine Learning course notes or the MIT 18.S096 Applications of Scientific Machine Learning course notes. In those notes there is a derivation that showcases how one can interpret partial differential equations as large systems of ODEs. Specifically, the Fisher-KPP equation that we look at in our paper:

\frac{\partial u}{\partial t} = D \frac{\partial^2 u}{\partial x^2} + u(1-u)

can be interpreted as a system of ODEs:

u_i^\prime = D(u_{i-1} - 2u_i + u_{i+1}) + u_i (1 - u_i)

The term multiplying D, u_{i-1} - 2u_i + u_{i+1}, is known as a stencil. Basically, at each point you sum the left and right neighbors and subtract twice the middle value. Sound familiar? It turns out that a convolutional layer from convolutional neural networks is actually just a parameterized form of a stencil. For example, a picture of a two-dimensional stencil looks like:

[Figure: animation of a two-dimensional stencil operation sweeping across a grid]

where this stencil is applying the operation:

1 0 1
0 1 0
1 0 1

A convolutional layer is just a parameterized form of a stencil operation:

w1 w2 w3
w4 w5 w6
w7 w8 w9
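To make the correspondence concrete, here is a one-dimensional version of the idea (matching the 1-D PDE example we are working with), using a small hypothetical apply_stencil helper defined inline:

```julia
# A 1-D illustration of “convolution = stencil”: sliding the weights
# [1, -2, 1] along a vector is exactly the second-difference operation
# that discretizes a second derivative (up to a grid-spacing factor).
apply_stencil(w, u) =
    [w[1]*u[i-1] + w[2]*u[i] + w[3]*u[i+1] for i in 2:length(u)-1]

x = range(0, 2π, length = 100)
u = sin.(x)
d2u = apply_stencil([1.0, -2.0, 1.0], u)   # ≈ u'' = -sin(x), scaled by (step size)^2
```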

Thus one way to approach learning spatiotemporal data that we think may come from such a generating process is:

u_i^\prime = D(w_1 u_{i-1} + w_2 u_i + w_3 u_{i+1}) + \text{NN}(u_i)

i.e., the discretization of the second spatial derivative, what’s known as diffusion, is physically represented as the stencil of weights [1 -2 1]. Notice that in this form, the entire spatiotemporal dataset is described by a 1-input 1-output neural network plus 3 parameters. Globally, over the array of all u_i, this becomes:

u^\prime = D \text{CNN}(u) + \text{NN}.(u)

i.e. it’s a universal differential equation with a 3-parameter CNN and (the same) small R \rightarrow R neural network applied at each spatial point.
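A rough sketch of that right-hand side, with placeholder untrained weights and an assumed (pinned-endpoint) boundary treatment chosen just for illustration:

```julia
using OrdinaryDiffEq

# Universal PDE right-hand side: a learnable 3-weight stencil (the “CNN”)
# for the spatial term, plus one small scalar network applied at every
# grid point. All weights here are untrained placeholders.
w = [1.0, -2.0, 1.0]                          # stencil weights, to be learned
nW1, nb1, nW2 = randn(8), randn(8), randn(1, 8)
nn(ui) = (nW2 * tanh.(nW1 .* ui .+ nb1))[1]   # the R → R network

function upde!(du, u, p, t)
    D = p
    for i in 2:length(u)-1
        du[i] = D*(w[1]*u[i-1] + w[2]*u[i] + w[3]*u[i+1]) + nn(u[i])
    end
    du[1] = 0.0; du[end] = 0.0                # placeholder boundary handling
end

xs = range(-5, 5, length = 101)
prob = ODEProblem(upde!, exp.(-xs .^ 2), (0.0, 1.0), 1.0)  # illustrative initial bump
sol = solve(prob, Tsit5())
```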

Can this tiny neural network actually fit the data? It does. But not only does it fit the data, it also is interpretable:

[Figure: the trained Fisher-KPP universal PDE; panels (C) and (D) show the recovered stencil weights and reaction term]

You see, not only did it fit and accurately match the data, it also tells us exactly the PDE that generated the data. Notice that figure (C) shows that w_1 = w_3 in the fitted equation, and that w_2 = - (w_1 + w_3). This means that the convolutional neural network learned to be D[1,-2,1], exactly as the theory would predict if the only spatial physical process was diffusion. Secondly, figure (D) shows that the neural network representing the pointwise behavior appears to be quadratic. Indeed, remember from the PDE discretization:

u_i^\prime = D(u_{i-1} - 2u_i + u_{i+1}) + u_i (1 - u_i)

it really is quadratic. This tells us that the only physical generating process that could have given us this data is a diffusion equation with a quadratic reaction, i.e.

\frac{\partial u}{\partial t} = D \frac{\partial^2 u}{\partial x^2} + u(1-u)

thus interpreting the neural networks precisely recovers a PDE governing the evolution of the spatiotemporal data. Not only that, this trained form can predict beyond its data set: if we wished for it to predict the behavior of a fluid with a different diffusion constant D, we know exactly how to change that term without retraining the neural networks, since the weights in the convolutional layer are D[1,-2,1]. We’d simply rescale those weights and suddenly have a model that predicts for a different underlying fluid.
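In code, that prediction-without-retraining step is nothing more than a rescaling (D_old and D_new are illustrative values, and w is the trained stencil from the sketch above):

```julia
# Because the trained convolutional weights factor as D * [1, -2, 1],
# predicting for a new diffusion constant is a rescaling, not a retraining.
D_old, D_new = 1.0, 2.5
w_new = (D_new / D_old) .* w
```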

Small neural network, small data, trained in an interpretable form that extrapolates, even on hard problems like spatiotemporal data.

Why It Works: Scientific Knowledge is Encoded in Structure, Not Data Points

From this explanation, it should be very clear that our approach is general, but in every application it’s specific. We utilize prior knowledge of differential equations, like known physics or biological interactions, to try to hard-code as much of the equation as possible. The neural networks are then just stand-ins for the little pieces that are left over. Thus the neural networks have a very easy job: they don’t have to learn very much! They just have to learn the leftovers! The problem becomes easy because we imposed so much knowledge in the structure itself by utilizing the differential equation form to its fullest.

I like to think of it as follows. There is a certain amount of knowledge K that is required to effectively learn the problem. Knowledge can come from prior information P or it can come from data D. Either way, you need enough knowledge, P + D > K, to effectively learn the model and make accurate predictions. Machine learning has gone the route of relying almost entirely on data, but that doesn’t need to be the case. We know how physics works, and how time series relate to derivatives, so there’s no reason to force a neural network to learn these parts. Instead, by writing small neural networks inside of differential equations, we can embed everything that we know about the physics as true structurally-imposed prior knowledge, and then what’s left is a simple training problem. That way a big P and a small D still satisfies P + D > K, and that’s how you make a “neural network” accurately extrapolate from only a small amount of training data. And now that it’s only learning a simple function, what it learned is easily interpretable through sparse regression techniques.

Software and Performance Issues

Once we had this idea of wanting to embed structure, then all of the hard work came. Essentially, in big neural network machine learning, you can get away with a lot of performance issues if 99.9% of your time is spent in the neural network’s calculations. But once we got into the regime of small-data, small-neural-network structured machine learning, our neural networks were not the big time consumer, which meant every little detail mattered. Thus we needed to hyper-optimize small ODE solves to make this a reality. As a result of our optimizations, we have easily reproducible benchmarks which showcase a 50,000x acceleration over the torchdiffeq neural ODE library. In fact, benchmarks show across-the-board orders-of-magnitude performance advantages over SciPy, MATLAB, and R’s deSolve as well. This is not a small detail, as training neural networks within scientific simulations is a costly process which takes many millions of ODE solves, and therefore these performance optimizations changed the problem from “impractical” to “reality”. Again, when very large neural networks are involved this may be masked by the cost of the neural network passes themselves, but in the context of small-network scientific machine learning, this change was a godsend.

But the even larger difficulty that we noticed was that traditional numerical analysis ideas like stability really came into play once real physical models got involved. There is a property of ODEs called stiffness, and when it comes into play, the simple Runge-Kutta or Adams-Bashforth-Moulton methods are no longer stable enough to accurately solve the equations (a small illustration follows the feature list below). Thus when looking at the universal partial differential equations, we had to make use of a set of ODE solvers which have package implementations in Julia and Fortran. Additionally, any form of backwards solving is unconditionally unstable on the diffusion-advection equation, meaning that it ended up being a practical case where the simple adjoint methods, like the backsolve approach of the original neural ODEs paper and torchdiffeq, actually diverge to infinity in finite time for any tolerance on the ODE solver. Thus we had to implement a number of different (checkpointed) adjoint methods in order to accurately and efficiently train neural networks within these equations. Then, once we had a partial differential equation form, we had to build tools that would integrate with automatic differentiation to automatically specialize on sparsity. The result was a full set of advanced methods for efficiently handling stiffness that is fully compatible with neural network backpropagation. It was only when all of this came together that the most difficult examples of what we showed actually worked. Now, our software DifferentialEquations.jl with DiffEqFlux.jl is able to handle:

  • Stiffness and ill-conditioned problems
  • Universal ordinary differential equations
  • Universal stochastic differential equations
  • Universal delay differential equations
  • Universal differential algebraic equations
  • Universal partial differential equations
  • Universal (event-driven) hybrid differential equations

all with GPUs, sparse and structured Jacobians, preconditioned Newton-Krylov, and the list of features just keeps going. This is the limitation: when real scientific models get involved, the numerical complexity drastically increases. But now this is something that at least Julia libraries have solved.
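As promised above, here is a small illustration of the stiffness issue, using the classic Robertson chemical kinetics problem (a standard stiff test case, not one from the paper):

```julia
using OrdinaryDiffEq

# The Robertson problem: three reactions with rate constants spanning
# nine orders of magnitude, which is what makes the system stiff.
function rober!(du, u, p, t)
    y1, y2, y3 = u
    du[1] = -0.04y1 + 1e4*y2*y3
    du[2] = 0.04y1 - 1e4*y2*y3 - 3e7*y2^2
    du[3] = 3e7*y2^2
end

prob = ODEProblem(rober!, [1.0, 0.0, 0.0], (0.0, 1e5))
sol = solve(prob, Rodas5())   # a stiff-aware Rosenbrock method handles this easily
# An explicit method like Tsit5() would grind to a halt here taking tiny steps;
# an analogous instability is why naive backsolve adjoints blow up on
# diffusion-advection equations.
```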

Final Thoughts

The takeaway is that not only do you need to use all of the scientific knowledge available, but you also need to make use of all of the numerical analysis knowledge. When you combine all of this knowledge with the most recent advances in machine learning, you get small neural networks that train on small data in a way that is interpretable and accurately extrapolates. So yes, it’s not magic: we just replaced the big data requirement with the requirement of having some prior scientific knowledge, and if you go talk to any scientist you’ll know this source of knowledge exists. I think it’s time we use it in machine learning.
