Curvature¤
In this introductory example we compute various curvature quantities from the Riemannian metric on a manifold and compare them against analytically known results. Everything here should be accessible with a basic knowledge of scientific computing and differential geometry. There are two JAX-specific transformations which we explain briefly below; for more detail, please see the official guides.
- `jax.jit`: Short for Just-in-Time compilation, this converts JAX Python functions to an optimised sequence of primitive operations, which are then passed to a hardware accelerator. The output of `jit` is another function - usually one that executes significantly faster than the Python equivalent. The price to be paid is that the program logic of a `jit`-compatible function is constrained by the compiler, so you don't want (or need) to `jit` everything.
- `jax.vmap`: Short for Vectorising Map, this transforms a JAX Python function written for execution on a single array element into one which is automatically vectorised across the specified array axes. Again, the program logic of a `vmap`-compatible function is restricted.
JAX transformations compose - you can `jit` a `vmap`-ed function and vice-versa. And that's pretty much all you need to know to understand this example!
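As a minimal sketch of how the two compose (a toy function, not part of `cymyc`):

import jax.numpy as jnp
from jax import jit, vmap

def scalar_fn(x):
    # toy function acting on a single scalar element
    return jnp.sin(x)**2 + x

batched_fn = jit(vmap(scalar_fn))            # vectorise over the leading axis, then compile
batched_fn(jnp.linspace(0., 1., 8)).shape    # (8,)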
While not a dependency of the package, the example notebooks require `jupyter` to be installed; run the following locally if you haven't already.
pip install --upgrade jupyter notebook
import jax
from jax import random, jit, vmap
import jax.numpy as jnp
import os, time
import numpy as np
from functools import partial
jax.config.update("jax_enable_x64", True)
Manifold definition / point sampling¤
The routines in this library work for an arbitrary real or complex manifold from which points may be sampled. In this example, we consider complex projective space \(\mathbb{P}^n\). This is the space of complex lines in \(\mathbb{C}^{n+1}\) which pass through the origin.
To sample from \(\mathbb{P}^n\), we use the fact that every complex line intersects the unit sphere along a circle, whose \(U(1)\) action we mod out, \(\mathbb{P}^n \simeq S^{2n+1} / U(1)\). This means that samples from the unit sphere, appropriately complexified, give samples in homogeneous coordinates on projective space. Here we set \(n=5\).
from cymyc.utils import math_utils
ambient_dim = 5
N = 10000
seed = int(time.time()) # 42
rng = random.PRNGKey(seed)
rng, _rng = random.split(rng)
def S2np1_uniform(key, n_p, n, dtype=np.float64):
    """
    Sample `n_p` points uniformly on the unit sphere $S^{2n+1}$, treated as CP^n.
    """
    # Sample from a standard normal in R^{2(n+1)} and project onto the unit sphere
    x = random.normal(key, shape=(n_p, 2*(n+1)), dtype=dtype)
    x_norm = x / jnp.linalg.norm(x, axis=1, keepdims=True)
    # Pair up real coordinates into complex homogeneous coordinates on CP^n
    sample = math_utils.to_complex(x_norm.reshape(-1, n+1, 2))
    return jnp.squeeze(sample)
Z = S2np1_uniform(_rng, N, ambient_dim)
Z
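As a quick sanity check (an extra cell, not part of the original flow), the sampled homogeneous coordinates should have unit norm, \(\sum_i |Z_i|^2 = 1\):

jnp.allclose(jnp.linalg.norm(Z, axis=-1), 1.)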
We now use the scaling freedom in projective space to convert homogeneous coordinates on \(\mathbb{C}\mathbb{P}^n\), \(\left[z_0 : \cdots : z_n\right]\), to inhomogeneous coordinates in a local coordinate chart where \(z_{\alpha}\) is nonzero, by setting \(z_{\alpha} = 1\) and removing it from the coordinate description.
Z, _ = math_utils.rescale(Z)
z = vmap(math_utils._inhomogenize)(Z)
z.shape
Metric definition¤
There is a natural metric on \(\mathbb{P}^n\) - the Fubini-Study metric. Viewing \(\mathbb{P}^n\) as the quotient \(S^{2n+1} / U(1)\), the Fubini-Study metric is the unique metric such that the projection \(\pi: S^{2n+1} \rightarrow \mathbb{P}^n\) is a Riemannian submersion. In inhomogeneous coordinates \(\zeta\),
\[ g_{i\bar{j}} = \frac{\left(1 + |\zeta|^2\right)\delta_{i\bar{j}} - \bar{\zeta}_i\, \zeta_j}{\left(1 + |\zeta|^2\right)^2}\,. \]
The function below returns the FS metric in local coordinates. Note that it requires a real input for autodiff to play nicely, so we use the map \(\zeta \mapsto p = \left(\textsf{Re}\,\zeta, \textsf{Im}\,\zeta\right)\), identifying \(\mathbb{C}^n \simeq \mathbb{R}^{2n}\).
def fubini_study_metric(p):
    """
    Returns the FS metric on CP^n evaluated at `p`.

    Parameters
    ----------
    `p`    : 2*complex_dim real inhomogeneous coords at
             which the metric matrix is evaluated. Shape [i].
    Returns
    ----------
    `g_FS` : Hermitian FS metric on CP^n, $g_{i \bar{j}}$. Shape [i,j].
    """
    # Recover complex inhomogeneous coords from stacked (real, imaginary) parts
    complex_dim = p.shape[-1]//2
    zeta = jax.lax.complex(p[:complex_dim], p[complex_dim:])
    zeta_bar = jnp.conjugate(zeta)
    zeta_sq = 1. + jnp.sum(zeta * zeta_bar)
    zeta_outer = jnp.einsum('...i,...j->...ij', zeta_bar, zeta)
    delta_mn = jnp.eye(complex_dim, dtype=jnp.complex64)

    g_FS = jnp.divide(delta_mn * zeta_sq - zeta_outer, jnp.square(zeta_sq))
    return g_FS
p = math_utils.to_real(z)
g_FS = vmap(fubini_study_metric)(p)
g_FS.shape
We can benchmark execution times with and without `jit` compilation - note that the exact speedup will depend on the hardware available.
%%timeit
_ = vmap(fubini_study_metric)(p).block_until_ready()
%%timeit
_ = vmap(jit(fubini_study_metric))(p).block_until_ready()
The Kähler potential¤
\(\mathbb{P}^n\) is a Kähler manifold - this imbues it with many special properties, one of them being that the metric is locally determined by a single real scalar function, the Kähler potential, \(\mathcal{K} \in C^{\infty}(\mathbb{P}^n)\).
This is particularly important in the context of approximating metrics, as it allows one to reduce the problem to approximation of a single scalar function.
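Concretely, in local coordinates the metric components are recovered from the potential by a mixed second derivative (standard Kähler conventions assumed),
\[ g_{i\bar{j}} = \partial_i\, \partial_{\bar{j}}\, \mathcal{K}, \qquad \mathcal{K}_{FS} = \log\left(1 + |\zeta|^2\right), \]
which the cells below check numerically using `curvature.del_z_bar_del_z`.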
def fubini_study_potential(p):
    """
    Returns the Kahler potential associated with the FS metric
    in CP^n evaluated at `p`.

    Parameters
    ----------
    `p`   : 2*complex_dim real inhomogeneous coords at
            which the potential is evaluated. Shape [i].
    Returns
    ----------
    `phi` : Kahler potential, real scalar. Shape [].
    """
    # |zeta|^2 = sum of squares of the stacked real coordinates
    zeta_sq = jnp.sum(p**2)
    return jnp.log(1. + zeta_sq)
from cymyc import curvature
_g_FS = vmap(curvature.del_z_bar_del_z, in_axes=(0,None))(p, fubini_study_potential)
_g_FS.shape
jnp.allclose(g_FS, _g_FS)
Riemann tensor¤
Measures of curvature corresponding to a given metric tensor involve derivatives of the metric - if a function corresponding to the metric tensor is known, these may be easily computed numerically using autodiff. The most important curvature quantity is the Riemann curvature - the endomorphism-valued two-form that informs us about local curvature effects, \(\textsf{Riem} \in \Omega^2(X; \textsf{End}(T_X))\).
Schematically, the curvature tensor is given by taking two derivatives of the metric tensor w.r.t. the input coordinates. For a Kähler metric in local coordinates, with \(\Gamma\) denoting the Levi-Civita connection,
\[ \Gamma^{\kappa}_{\mu\lambda} = g^{\kappa\bar{\sigma}}\, \partial_{\mu}\, g_{\lambda\bar{\sigma}}, \qquad R^{\kappa}{}_{\lambda\mu\bar{\nu}} = -\partial_{\bar{\nu}}\, \Gamma^{\kappa}_{\mu\lambda}\,. \]
riem = vmap(curvature.riemann_tensor_kahler, in_axes=(0,None))(p, jax.tree_util.Partial(fubini_study_metric))
This involves two derivatives of a potentially expensive function, but is reasonably speedy even for \(10^4\) points, as we can test by benchmarking - in this case the function is already `jit`-ed at definition. Note that nested `jit`s are equivalent to a single `jit`.
%%timeit
riem = vmap(curvature.riemann_tensor_kahler, in_axes=(0,None))(p, jax.tree_util.Partial(fubini_study_metric)).block_until_ready()
riem.shape
First Bianchi identity¤
We form the Riemann tensor with all indices lowered using the musical isomorphism defined by the metric. The resulting tensor satisfies the following symmetries, as a consequence of the first Bianchi identity:
\[ R_{abcd} = R_{adcb} = R_{cbad} = R_{cdab}\,. \]
riem_lower = jnp.einsum('...ibcd, ...ia->...bacd', riem, g_FS)
jnp.allclose(riem_lower, jnp.einsum('...abcd->...adcb', riem_lower)) # first equality
jnp.allclose(riem_lower, jnp.einsum('...abcd->...cbad', riem_lower)) # second equality
jnp.allclose(riem_lower, jnp.einsum('...abcd->...cdab', riem_lower)) # third equality
Ricci curvature¤
Complex projective space is an Einstein manifold, meaning that the Fubini-Study metric on \(\mathbb{P}^n\) is proportional to the Ricci curvature. The Ricci curvature is another important measure of curvature derived from \(\textsf{Riem}\), which roughly measures the degree of volume distortion relative to Euclidean space as one travels along geodesics emanating from a given point.
For \(\mathbb{P}^n\) the Einstein constant is \(\Lambda = n+1\), i.e. \(\textsf{Ric} = (n+1)\, g_{FS}\).
In local coordinates, the Ricci curvature is given as the trace over the endomorphism part of the Riemann curvature tensor, \(\textsf{Ric}_{\mu\bar{\nu}} = R^{\kappa}{}_{\kappa\mu\bar{\nu}}\).
ricci = vmap(curvature.ricci_tensor_kahler, in_axes=(0,None))(p, jax.tree_util.Partial(fubini_study_metric))
jnp.allclose(ricci, (ambient_dim + 1) * g_FS)
This also means that the Ricci scalar, the trace of the Ricci curvature, should be the constant \(n(n+1) = 30\) on \(\mathbb{P}^5\):
jnp.einsum('...ba, ...ab', jnp.linalg.inv(g_FS), ricci)
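As a final check (an extra cell following the pattern of the earlier `allclose` verifications, assuming the Einstein relation above):

jnp.allclose(jnp.einsum('...ba, ...ab', jnp.linalg.inv(g_FS), ricci), ambient_dim * (ambient_dim + 1))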