Surveying numerical methods (finite difference methods) and physics-informed neural networks to solve a 1D heat equation. This post was heavily inspired by:

Introduction

Physics-Informed Machine Learning (PIML) is an exciting subfield of Machine Learning that aims to incorporate physical laws and/or constraints into statistical machine learning. The representations of the laws and constraints can be categorized into three groups (with decreasing strength of inductive bias):

  • Partial differential equations (PDE)
  • Symmetry: translation, rotation invariant.
  • And intuitive physical constraints.

The PINN method incorporates PDE into the learning problem by adding PDE as a regularization term into the machine learning loss term.

Heat equations

This instance of the 1D heat equation describes how the temperature of an insulated rod changes over time (transient state) at any point on the rod, where the two ends of the rod are kept at a constant temperature of \(0^o C\) and the initial temperature of the rod was given by a function of location \(x\).

$$ \begin{equation} \begin{aligned} PDE: & & u_t = \alpha^2 u_{xx} & & 0 < x< 1 & & 0 < t < \infty \\ BCs: & & \begin{cases} u(0, t) = 0\\ u(1, t) = 0 \end{cases} & & 0 < t < \infty \\ IC: & & u(x, 0) = \sin(2\pi x) & & 0 \leq x \leq 1 \end{aligned} \end{equation} $$

Solving heat equation with variables seperation

Suppose that we can factorize \(u(x, t) = X(x)T(t)\), from the PDE we have:

$$ \begin{equation} \begin{aligned} & X(x)T^\prime(t) = \alpha^2 X^{\prime\prime}(x)T(t)\\ \implies & \frac{T^\prime(t)}{\alpha^2 T(t)} = \frac{X^{\prime\prime}(x)}{X(x)} = \mu \\ \implies & \begin{cases} T^\prime(t) - \mu\alpha^2 T(t) = 0 & & (2a) \\ X^{\prime\prime}(x) - \mu X(x) = 0 & & (2b) \end{cases} \end{aligned} \end{equation} $$

From equation (2a), \(T(t) = Ae^{\mu\alpha^2t}\). This implies \(\mu\) must be negative so that \(T\) doesn’t go to \(\infty\). Let \(\mu = -\lambda^2\), so \(T(t) = Ae^{-\lambda^2\alpha^2t}\). Replacing into (2), we have:

$$ \begin{equation} \begin{aligned} & X^{\prime\prime}(x) + \lambda^2 X(x) = 0 \\ \implies & X(x) = B \sin\lambda x + C\cos\lambda x \end{aligned} \end{equation} $$

Substitute \(T(t), X(x)\) into \(u(x, t)\):

$$ \begin{equation} u(x, t) = e^{-\lambda^2\alpha^2 t}(A\sin\lambda x + B\cos\lambda x) \end{equation} $$

Subsititute this into boundary conditions:

$$ \begin{equation} \begin{aligned} & \begin{cases} u(0, t) = 0 \\ u(1, t) = 0 \end{cases} \\ \implies & \begin{cases} e^{-\lambda^2\alpha^2 t}(A \sin 0 + B \cos 0)= 0 \\ e^{-\lambda^2\alpha^2 t}(A \sin \lambda + B \cos \lambda) = 0 \\ \end{cases}\\ \implies & \begin{cases} B = 0 \\ \lambda = n\pi & n = 1, 2, \cdots \end{cases} \end{aligned} \end{equation} $$

So for a given \(n\), we have a particular solution for \(u(x, t)\):

$$ \begin{equation} u_n(x, t) = A_n e^{-n^2\pi^2\alpha^2 t} \sin n\pi x \end{equation} $$

And the general solution for \(u(x, t)\):

$$ \begin{equation} u(x, t) = \sum_{n=1}^\infty{A}_n e^{-n^2\pi^2\alpha^2t} \sin n\pi x \end{equation} $$

Where \(A_n\) is given by:

$$ \begin{equation} A_n = 2\int_0^1 \sin 2\pi x\sin n\pi x dx = \begin{cases} 0 \quad \text{if }n \neq 2\\ 1 \quad \text{if }n = 2 \end{cases} \end{equation} $$

Finally, we have the solution to the PDE:

$$ \begin{equation} \blue{u(x, t) = e^{-4\pi^2\alpha^2t}\sin 2\pi x} \end{equation} $$

Finite Difference Method

Numerical approximation of first and second order derivative

First Order Forward Difference

Consider a Taylor series expansion of \(\phi(x)\) about point \(x_i\):

$$ \begin{equation} \begin{aligned} & \phi(x_i + \delta x) = \phi(x_i) + \frac{\partial \phi}{\partial x}\bigg\vert_{x_i} \delta x + \frac{\partial^2 \phi}{\partial x^2}\bigg\vert_{x_i} \frac{\delta x^2}{2!} + \frac{\partial^3 \phi}{\partial x^3}\bigg\vert_{x_i} \frac{\delta x^3}{3!} + \cdots \\ \end{aligned} \end{equation} $$

Replace \(\delta x = \Delta x \ll 1\) in equation (10):

$$ \begin{aligned} & \phi(x_i + \delta x) = \phi(x_i) + \frac{\partial \phi}{\partial x}\bigg\vert_{x_i} \Delta x + \frac{\partial^2 \phi}{\partial x^2}\bigg\vert_{x_i} \frac{\Delta x^2}{2!} + \frac{\partial^3 \phi}{\partial x^3}\bigg\vert_{x_i} \frac{\Delta x^3}{3!} + \cdots \\ \implies & \frac{\partial \phi}{\partial x}\bigg\vert_{x_i} = \frac{\phi(x_i +\Delta x) - \phi(x_i)}{\Delta x} \red{\underbrace{ - \frac{\partial^2 \phi}{\partial x^2}\bigg\vert_{x_i} \frac{\Delta x}{2!} - \frac{\partial^3 \phi}{\partial x^3}\bigg\vert_{x_i} \frac{\Delta x^2}{3!} - \cdots }_{\text{Truncation error: }\mathcal{O}(\Delta x)}}\\ & \blue{\approx \frac{\phi(x_i +\Delta x) - \phi(x_i)}{\Delta x}} \end{aligned} $$

Note that in this tutorial, the truncation error is \(\mathcal{O}(\Delta x^2)\). I haven’t been able to understand why yet!!!.

First Order Backward Difference

Replace \(\delta x = -\Delta x, \Delta x \ll 1\) in equation (10):

$$ \begin{equation} \begin{aligned} & \phi(x_i + \delta x) = \phi(x_i) - \frac{\partial \phi}{\partial x}\bigg\vert_{x_i} \Delta x + \frac{\partial^2 \phi}{\partial x^2}\bigg\vert_{x_i} \frac{\Delta x^2}{2!} - \frac{\partial^3 \phi}{\partial x^3}\bigg\vert_{x_i} \frac{\Delta x^3}{3!} + \cdots \\ \implies & \frac{\partial \phi}{\partial x}\bigg\vert_{x_i} = \frac{\phi(x_i) - \phi(x_i - \Delta x)}{\Delta x} \red{\underbrace{ + \frac{\partial^2 \phi}{\partial x^2}\bigg\vert_{x_i} \frac{\Delta x}{2!} - \frac{\partial^3 \phi}{\partial x^3}\bigg\vert_{x_i} \frac{\Delta x^2}{3!} + \cdots }_{\text{Truncation error: }\mathcal{O}(\Delta x)}}\\ & \blue{\approx \frac{\phi(x_i) - \phi(x_i - \Delta x)}{\Delta x}} \end{aligned} \end{equation} $$

Second Order Central Difference

Replace in equation (10):

  • \(\delta x = \Delta x\)

$$ \begin{equation} \begin{aligned} & \phi(x_i + \Delta x) = \phi(x_i) + \frac{\partial \phi}{\partial x}\bigg\vert_{x_i} \Delta x + \frac{\partial^2 \phi}{\partial x^2}\bigg\vert_{x_i} \frac{\Delta x^2}{2!} + \frac{\partial^3 \phi}{\partial x^3}\bigg\vert_{x_i} \frac{\Delta x^3}{3!} + \cdots \end{aligned} \end{equation} $$

  • \(\delta x = -\Delta x\)

$$ \begin{equation} \begin{aligned} & \phi(x_i - \Delta x) = \phi(x_i) - \frac{\partial \phi}{\partial x}\bigg\vert_{x_i} \Delta x + \frac{\partial^2 \phi}{\partial x^2}\bigg\vert_{x_i} \frac{\Delta x^2}{2!} - \frac{\partial^3 \phi}{\partial x^3}\bigg\vert_{x_i} \frac{\Delta x^3}{3!} + \cdots \end{aligned} \end{equation} $$

Adding equation (12) and (13) we have:

$$ \begin{equation} \begin{aligned} & \phi(x_i + \Delta x) + \phi(x_i - \Delta x) = 2 \phi(x_i) + 2 \frac{\partial^2 \phi}{\partial x^2}\bigg\vert_{x_i} \frac{\Delta x^2}{2!} + 2 \frac{\partial^4 \phi}{\partial x^4}\bigg\vert_{x_i} \frac{\Delta x^4}{4!} + \cdots\\ \implies & \frac{\partial^2 \phi}{\partial x^2}\bigg\vert_{x_i} = \frac{\phi(x_i + \Delta x) - 2\phi(x_i) + \phi(x - \Delta x)}{\Delta x^2} \red{\underbrace{ - 2 \frac{\partial^4 \phi}{\partial x^4}\bigg\vert_{x_i} \frac{\Delta x^2}{4!} - \cdots}_{\mathcal{O}(\Delta x^2)}}\\ & \blue{ \approx \frac{\phi(x_i + \Delta x) - 2\phi(x_i) + \phi(x - \Delta x)}{\Delta x^2} } \end{aligned} \end{equation} $$

Finite Difference Method for the Heat Equation

Discretize the domain \(\mathcal{D} = (0, 1) \times (0, T)\) by constructing a grid \(\{x_i\}_{i=1\cdots N} \times \{t_m\}_{m=1\cdots M}\). Where:

  • \(x_i = (i - 1) \Delta x,\quad \Delta x = \frac{1}{N - 1}\)
  • \(t_m = (m - 1) \Delta t,\quad \Delta t = \frac{T}{M - 1}\)

Let \(u(x, t)\) be the true solution to the PDE

Forward Time, Centered Space (FTCS)

  • Using First Order Forward Difference (equation 10) to approximate parital derivative of \(u\) at a grid point \((x_i, t_m)\):

    $$ \begin{equation} \begin{aligned} \frac{\partial u}{\partial t} \bigg\vert_{x=x_i, t=t_m} & = \frac{u(x_i, t_m + \Delta t) - u(x_i, t_m)}{\Delta t} + \mathcal{O}(\Delta t)\\ & \approx \frac{u_i^{m+1}-u_i^m}{\Delta t} \end{aligned} \end{equation} $$

  • Using Second Order Central Difference (equation 14) to approximate the second order partial derivative of \(u\) with respect to \(x\) at the grid point: $$ \begin{equation} \begin{aligned} \frac{\partial^2 u}{\partial x^2} \bigg\vert_{x=x_i, t=t_m} & = \frac{u(x_i + \Delta x, t_m) - 2 u(x_i, t_m) + u(x_i - \Delta x, t_m)}{\Delta x^2} + \mathcal{O}(\Delta x^2)\\ & \approx \frac{u_{i+1}^m - 2 u_i^m + u_{i-1}^m}{\Delta x^2} \end{aligned} \end{equation} $$

Where \(u_i^m\) is the numerical approximation of true function evaluated at the grid point \((x_i, t_m)\). Replacing equation (15) and (16) into the LHS and RHS of the PDE in (1):

$$ \begin{equation} \begin{aligned} & \frac{u_i^{m+1}-u_i^m}{\Delta t} = \alpha^2 \frac{u_{i+1}^m - 2 u_i^m + u_{i-1}^m}{\Delta x^2}\\ \implies & u_i^{m+1} = u_i^m + \frac{\alpha^2 \Delta t}{\Delta x^2} (u_{i+1}^m - 2 u_i^m + u_{i-1}^m) \\ & = \blue{u_i^m(1 - 2r) + r(u_{i+1}^m + u_{i-1}^m)} \end{aligned} \end{equation} $$

Where \(r = \frac{\alpha^2\Delta t}{\Delta x^2}\). In order for \(u(x, t)\) reach steady state, \(r\) must be smaller than \(\frac{1}{2}\). The proof was provided in Von Neumann Stability Analysis.

TODO: Haven’t understood yet !!!

Equation (17) allows us to sequentially compute the approximation \(u_i^m\) at any point \((x_i, t_m)\), where \(u_i^1 = u(x_i, 0), i = 1\cdots N\) were given by the initial and boundary conditions. In matrix notation, the series of equation can be written as:

$$ \begin{equation} \begin{aligned} \begin{bmatrix} 1 & 0 & 0 & 0 & 0 & 0 & 0 \\ r & 1 - 2r & r & \cdots & 0 & 0 & 0 \\ \vdots & \vdots &\vdots &\ddots &\vdots &\vdots &\vdots \\ 0 &0 & 0 & \cdots & r & 1 - 2r & r \\ 0 & 0 & 0 & 0 & 0 & 0 & 1 \end{bmatrix} \underbrace{\begin{bmatrix} u_1^m \\ u_2^m \\ \vdots \\ u_{N-1}^m \\ u_N^m \\ \end{bmatrix}}_{u^m} = \underbrace{\begin{bmatrix} u_1^{m+1} \\ u_2^{m+1} \\ \vdots \\ u_{N-1}^{m+1} \\ u_N^{m+1} \\ \end{bmatrix}}_{u^{m+1}} \end{aligned} \end{equation} $$

Note that \(u_1^m\), \(u_N^m\) are always equal to its value in the next time step. This is due to the boundary condition, the temperature at the boundary is always \(0\).

(code) Implementation of FTCS scheme solve_fdm()
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import numpy as np 


def solve_fdm(N: int, M: int, T: float):
    """
    solving 1D heat equation:
        PDE: u_t = u_xx (\alpha^2 = 1)
        BCs: u(0, t) = u(1, t) = 0
        ICs: u(x, 0) = x - x**2

    args:
        - N, M  : number of collocation points
            in spacial and temporal dimension
        - T     : solving from t = 0 to T
    """

    # constructing the grid
    dx = 1 / (N - 1) # 0 <= x <= 1
    dt = T / (M - 1) # 0 < t <= T

    r = dt/dx**2 # (alpha = 1)

    # Condition for numerical stability
    assert r < .5, ValueError(f"Choose smaller r, r={r:.4f}")

    x_grid = np.linspace(0, 1, N)

    # approximate the result 
    U = np.zeros((N, M)) # already satisfied the BCs

    # IC impose initial condition
    ic = lambda x: np.sin(2 * np.pi * x) 
    U[:, 0] = np.vectorize(ic)(x_grid)

    # kernel to approximate 2nd derivative of u wrt x
    ker = np.array([1., -2., 1.], dtype=np.float64)

    for i in range(1, M):
        ut = np.convolve(U[:, i - 1], ker, mode="same")
        U[:,i] = U[:, i-1] + r * ut

    return U

1
2
# solving the PDE with 100x4000 grid from t=0 to t=0.2
U = solve_fdm(100, 4000, .2)
imgimg2

Backward Time, Centered Space (BTCS)

  • Using First Order Backward Difference (equation 11) to approximate parital derivative of \(u\) at a grid point \((x_i, t_m)\):

    $$ \begin{equation} \begin{aligned} \frac{\partial u}{\partial t} \bigg\vert_{x=x_i, t=t_m} & = \frac{u(x_i, t_m) - u(x_i, t_m - \Delta t)}{\Delta t} + \mathcal{O}(\Delta t)\\ & \approx \frac{u_i^{m}-u_i^{m-1}}{\Delta t} \end{aligned} \end{equation} $$

  • Using Second Order Central Difference (equation 14) to approximate the second order partial derivative of \(u\) with respect to \(x\) at the grid point: $$ \begin{equation} \begin{aligned} \frac{\partial^2 u}{\partial x^2} \bigg\vert_{x=x_i, t=t_m} & = \frac{u(x_i + \Delta x, t_m) - 2 u(x_i, t_m) + u(x_i - \Delta x, t_m)}{\Delta x^2} + \mathcal{O}(\Delta x^2)\\ & \approx \frac{u_{i+1}^m - 2 u_i^m + u_{i-1}^m}{\Delta x^2} \end{aligned} \end{equation} $$

Replacing equation (19), and (20) into LHS and RHS of the PDE in (1) respectively we have:

$$ \begin{equation} \begin{aligned} & \frac{u_i^{m}-u_i^{m-1}}{\Delta t} = \alpha^2 \frac{u_{i+1}^m - 2 u_i^m + u_{i-1}^m}{\Delta x^2}\\ \implies & u_i^{m-1} = u_i^m - \frac{\alpha^2 \Delta t}{\Delta x^2} (u_{i+1}^m - 2 u_i^m + u_{i-1}^m) \\ & = \blue{u_i^m(1 + 2r) - r(u_{i+1}^m + u_{i-1}^m)} \end{aligned} \end{equation} $$

Where \(r = \frac{\alpha^2\Delta t}{\Delta x^2}\). Rewriting equation(21) in matrix notation:

$$ \begin{equation} \begin{aligned} \underbrace{\begin{bmatrix} 1 & 0 & 0 & 0 & 0 & 0 & 0 \\ -r & 1 + 2r & -r & \cdots & 0 & 0 & 0 \\ \vdots & \vdots &\vdots &\ddots &\vdots &\vdots &\vdots \\ 0 &0 & 0 & \cdots & -r & 1 + 2r & -r \\ 0 & 0 & 0 & 0 & 0 & 0 & 1 \end{bmatrix}}_A \underbrace{\begin{bmatrix} u_1^m \\ u_2^m \\ \vdots \\ u_{N-1}^m \\ u_N^m \\ \end{bmatrix}}_{\mathbf{u}^m} = \underbrace{\begin{bmatrix} u_1^{m-1} \\ u_2^{m-1} \\ \vdots \\ u_{N-1}^{m-1} \\ u_N^{m-1} \\ \end{bmatrix}}_{\mathbf{u}^{m-1}} \end{aligned} \end{equation} $$

So that we can sequentially compute the next state by solving the system of linear equations in (22):

$$ \blue{ \mathbf{u}^{m} = A^{-1} \mathbf{u}^{m-1}; \quad m = 2,\cdots M } $$

Where \(\mathbf{u}_i^1\) are given by the initial and boundary conditions. Unlike FTCS, BTCS are unconditionally stable with respect to the choice of \(r\). Therefore we can choose much fewer steps along temporal dimension.

(code) Implementation of BTCS scheme solve_bdm()
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import numpy as np
from scipy.sparse import diags

def solve_bdm(N: int, M: int, T: float):
    """
    solving 1D heat equation using BTCS scheme
        PDE: u_t = u_xx (\alpha^2 = 1)
        BCs: u(0, t) = u(1, t) = 0
        ICs: u(x, 0) = x - x**2

    args:
        - N, M  : number of collocation points
            in spacial and temporal dimension
        - T     : solving from t = 0 to T
    """
    # constructing the grid
    dx = 1 / (N - 1) # 0 <= x <= 1
    dt = T / (M - 1) # 0 < t <= T

    r = dt/dx**2 # (alpha = 1)

    # construct A:
    A = diags([-r, 1 + 2 * r, -r], [-1, 0, 1], shape=(N, N)).toarray()
    A[0, :] = 0
    A[-1, :] = 0
    A[0, 0] = 1
    A[-1,-1] = 1

    A_inv = np.linalg.inv(A)

    # approximate the result 
    U = np.zeros((N, M)) # already satisfied the BCs

    # IC impose initial condition
    x_grid = np.linspace(0, 1, N)
    ic = lambda x: np.sin(2 * np.pi * x) 
    U[:, 0] = np.vectorize(ic)(x_grid)

    for m in range(1, M):
        U[:, m] = A_inv @ U[:, m-1]

    return U

1
2
# solving the PDE with 100x100 grid from t=0 to t=0.2
U = solve_bdm(100, 100, .2)
imgimg2

We can see that results of FTCS and BTCS agree with each other, however, BTCS only using a \(100 \times 100\) grid while FTCS using \(100 \times 4000\) grid.

There is one more scheme for finite different methods which is Crank-Nicolson methods, which use central diffence method to estimate first order derivative !!

Physics Informed Neural Network

Let’s rewrite the heat equation in a more general form:

$$ \begin{equation} \begin{aligned} PDE: & & u_t = \alpha^2 u_{xx} & & 0 < x< 1 & & 0 < t < \infty \\ BCs: & & \begin{cases} u(0, t) = f_0(t)\\ u(1, t) = f_1(t) \end{cases} & & 0 < t < \infty \\ IC: & & u(x, 0) = \phi(t) & & 0 \leq x \leq 1 \end{aligned} \end{equation} $$

In our case, \(f_0(t) = f_1(t) = 0\) and \(\phi(x) = \sin 2\pi x\). PINN approximate the function \(u(x, t)\) by a neural network \(U_\theta(x, t)\), and then learn the networks parameters \(\theta\) by minimize the loss function:

$$ \begin{equation} \begin{aligned} \mathcal{L}(\theta) & = \frac{1}{N}\sum_{i=1}^N{[U_\theta(x_i, t_i) - u_i]^2} & \text{(Supervised loss)}\\ & + \frac{\lambda_j}{M_j} \sum_{j=1}^{M_j}{\bigg[ \frac{\partial U_\theta}{\partial t}- \alpha^2 \frac{\partial^2 U_\theta}{\partial x^2} \bigg](x_j, t_j)} & \text{(PDE residual)}\\ & + \frac{\lambda_k}{M_k} \sum_{k=1}^{M_k}{[(U_\theta(0, t_k) - f_0(t_k))^2 + (U_\theta(1, t_k) - f_1(t_k))^2]} & \text{(Boundary conditions)} \\ & + \frac{\lambda_h}{M_h} \sum_{h=1}^{M_h}{[U_\theta(x_h, 0) - \phi(x_h)]^2} & \text{(Initial condition)} \end{aligned} \end{equation} $$

  • The first term is the supervised loss, coinciding with statistical machine learning. Where \({(x_i, t_i, u_i)}_{i=1\cdots N}\) is the set of collocation points \((x_i, t_i)\), and value of \(u_i = u(x_i, t_i)\).

  • The second term is the PDE residual, where:

    • \(\frac{\partial U_\theta}{\partial t}\) is the partial derivative of the network \(U_\theta\) with respect to the time input \(t\)
    • Similarly, \(\frac{\partial^2 U_\theta}{\partial x^2}\) is the second derivative of the network with repsect to location \(x\).
  • Second and third terms are the initial and boundary conditions, given by equation (23).

We don’t necessarily have access to the first loss term. In the implementation bellow, I ignored the first loss term. For the remaining three loss terms:

  • PDE residual: \(x_j \sim \text{Uniform}(0, 1); t_j \sim\text{Uniform}(0, 0.2)\)
  • Boundary condition: \(t_k \sim \text{Uniform}(0, 0.2)\)
  • Initial condition: \(x_h \sim \text{Uniform}(0, 1)\)

Note: The code need some refactoring but it still works.

(code) JAX Implementation of PINN: train(), and compute_grid()
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import numpy as np
import click
import equinox as eqx
import jax 
import jax.numpy as jnp
import optax
import yaml
from matplotlib import pyplot as plt
from pydantic import BaseModel
from pydantic import PositiveInt
from tqdm import trange
from src.viz import plot_heatmap
from src.viz import animate 


# Plot config
plt.style.use('ggplot')


class PINNConfig(BaseModel):
    key: int 
    layers: list[PositiveInt]
    batch_ic_size: PositiveInt
    batch_bc_size: PositiveInt
    batch_interior_size: PositiveInt
    fn_path: str
    fn_path_gif: str


class U(eqx.Module):
    """
    Simple MLP taking (x, t) as input, return mlp(x, t)
    """

    layers: list #

    def __init__(self, layers: list[int], key):
        self.layers = []

        for _in, _out in zip(layers[:-1], layers[1:]):
            key, subkey = jax.random.split(key, 2)
            self.layers.append(eqx.nn.Linear(_in, _out, key=subkey))
    
    def __call__(self, x, t):
        """
        assuming x in R^{n x (d - 1)}, t in R
        """
        out = jnp.concatenate([x, t], axis=-1)

        for layer in self.layers[:-1]:
            out = layer(out)
            out = jax.nn.tanh(out)

        out = self.layers[-1](out)
        return jax.nn.tanh(out)


def interior_loss(u, x, t):
    """
    """
    # First and second derivative of u wrt x
    u_x = jax.grad(lambda x, t: jnp.squeeze(u(x, t)), argnums=0)
    u_xx = jax.grad(lambda x, t: jnp.squeeze(u_x(x, t)), argnums=0)
    u_t = jax.grad(lambda x, t: jnp.squeeze(u(x, t)), argnums=1)

    pde_resid = jax.vmap(u_t)(x, t) - jax.vmap(u_xx)(x, t)
    return jnp.mean(pde_resid**2)


def boundary_loss(u, x, t, f_bc: callable):
    """
    """
    # compute boundary value at each collocation point
    y = jax.vmap(f_bc)(x, t)
    y_hat = jax.vmap(u)(x, t)

    return jnp.mean((y - y_hat) ** 2)


def initial_condition_loss(u, x, t, f_ic):
    y = jax.vmap(f_ic)(x, t)
    y_hat = jax.vmap(u)(x, t)

    return jnp.mean((y - y_hat) ** 2)


def generate_interior_batch(key, n):
    """
    interior collocation points
    """
    # sample, discretizing interior point
    key, subkey = jax.random.split(key, 2)
    X = jax.random.uniform(subkey, shape=(n, 1), minval=1e-5, maxval=1-1e-5)

    key, subkey = jax.random.split(key, 2)
    T = jax.random.uniform(subkey, shape=(n, 1), minval=1e-5, maxval=.2)

    return X, T

def generate_ic_batch(key, n):
    """
    initial collocation points {(x_i, 0)}
    """
    # sample, discretizing interior point
    key, subkey = jax.random.split(key, 2)
    X = jax.random.uniform(subkey, shape=(n, 1), minval=1e-5, maxval=1-1e-5)

    key, subkey = jax.random.split(key, 2)
    T = jnp.zeros(shape=(n, 1))

    return X, T

def generate_bc_batch(key, n):
    """
    initial collocation points {(0/1, t_m)}
    """
    # sample, discretizing interior point
    key, subkey = jax.random.split(key, 2)
    X = jax.random.randint(subkey, shape=(n, 1), minval=0., maxval=2.)

    key, subkey = jax.random.split(key, 2)
    T = jax.random.uniform(subkey, shape=(n, 1), minval=1e-5, maxval=.2)

    return X, T


def loss_fn(u, x_i, t_i, x_ic, t_ic, x_bc, t_bc, f_ic, f_bc):
    """
    u: model
    x_i, t_i: interior collocation point
    x_ic, t_ic: initial points 
    x_bc, t_bc: boundar points 
    f_ic: initial condition
    f_bc: boundary condition 
    """
    return interior_loss(u, x_i, t_i) +\
            initial_condition_loss(u, x_ic, t_ic, f_ic) +\
            boundary_loss(u, x_bc, t_bc, f_bc)



def train(config: PINNConfig):
    key = jax.random.PRNGKey(config.key)
    key, subkey = jax.random.split(key, 2)

    # define the model
    u = U(config.layers, subkey)

    # define initial condition
    def f_ic(x, t):
        return jnp.sin(2 * jnp.pi * x)

    # define boundary condition
    def f_bc(x, t):
        return 0.

    # compute loss & loss gradient
    grad_loss_fn = jax.value_and_grad(loss_fn)
    
    @jax.jit
    def train_step(model, key, optim_state):
        # Generate data point
        ic_key, bc_key, i_key = jax.random.split(key, 3)
        x_ic, t_ic = generate_ic_batch(ic_key, config.batch_ic_size)
        x_bc, t_bc = generate_bc_batch(bc_key, config.batch_bc_size)
        x_i, t_i = generate_interior_batch(i_key, config.batch_interior_size)
    
        loss_val, grads = grad_loss_fn(
                model, x_i, t_i,
                x_ic, t_ic,
                x_bc, t_bc,
                f_ic, f_bc)
        updates, optim_state = optim.update(grads, optim_state) 
        new_model = eqx.apply_updates(model, updates)
    
        return loss_val, new_model, key, optim_state

    # optimizer
    optim = optax.adam(1e-3)
    optim_state = optim.init(u)

    losses = []
    pbar = trange(10000)

    for i in pbar:
        loss, u, key, optim_state = train_step(u, key, optim_state)
        pbar.set_description(f"Loss = {loss:.4f}")
        losses.append(loss)

    losses = jnp.array(losses)

    return u, losses


def compute_grid(model: eqx.Module, config: PINNConfig):
    """
    compute model output on 100 x 100 grid over following domain
    x in (0, 1)
    t in (0, 0.2)
    """

    def _compute(x, t):
        # transform scalar into a 1d vector x -> [x]
        x, t = jnp.expand_dims(x, 0),\
                jnp.expand_dims(t, 0)

        u = model(x, t)
        u = jnp.squeeze(u)

        return u

    x = jnp.linspace(0, 1, 100)
    t = jnp.linspace(0, 0.2, 100)

    xx, tt = np.meshgrid(x, t, sparse=True)
    U = np.vectorize(_compute)(xx, tt).T

    return U

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
"""
# Config
params:
  key: 0
  layers: [2, 128, 64, 32, 1]
  batch_ic_size: 16
  batch_bc_size: 16
  batch_interior_size: 128

"""
@click.command()
@click.option("--config", "-C", type=str, required=True, help="path/to/config")
def main(config: str):
    """
    Solving PDE using PINN
        PDE: u_t = alpha^2 u_{xx}       0 < x < 1; 0 < t < \infty
        BCs: u(0, t) = u(1, t) = 0      0 < t < \infty
        ICs: u(x, 0) = sin(2pi x)       0 < x < 1 
    
    Params:
        alpha^2 = 1 
    """
    with open(config, "r") as f:
        config = yaml.safe_load(f)
        config = PINNConfig(**config["params"])

    model, losses = train(config)
    U = compute_grid(model, config)

    plot_heatmap(U, config.fn_path)
    animate(U, config.fn_path_gif)

    return