
Curvature

In this introductory example we compute various curvature quantities from the Riemannian metric on a manifold and compare them against analytically known results. Everything here should be accessible with a basic knowledge of scientific computing and differential geometry. We use two JAX-specific transformations, explained briefly below; for more detail, see the official JAX guides.

  • jax.jit: Short for Just-in-Time compilation. jit traces a JAX Python function into a sequence of primitive operations and compiles it with XLA for the available backend (CPU, GPU or TPU). The output of jit is another function, usually one that executes significantly faster than the Python equivalent. The price to be paid is that the program logic of a jit-compatible function is constrained by the compiler (for example, Python control flow cannot depend on the values of traced arrays), so you don't want (or need) to jit everything.
  • jax.vmap: Short for Vectorising Map. vmap transforms a JAX Python function written for a single input (e.g. one point) into one that is automatically vectorised across the specified array axes (e.g. a batch of points). Again, the program logic of a vmap-compatible function is restricted.

JAX transformations are composable: you can jit a vmap-ed function and vice versa. That's pretty much all you need to know to understand this example!
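As a minimal illustration of the two transformations and their composition, the snippet below writes a function for a single point and batches it in both orders. The functions `squared_norm` and `scaled_norm` and the toy batch are hypothetical, not part of the library. The last call uses `in_axes=(0, None)`, which maps over the first argument while passing the second unchanged to every point; this is the same pattern used with the curvature routines later on.

```python
import jax.numpy as jnp
from jax import jit, vmap

def squared_norm(x):
    # Written for a single point x of shape (d,)
    return jnp.sum(x**2)

def scaled_norm(x, scale):
    # The second argument is a scalar shared by every point
    return scale * jnp.sum(x**2)

batch = jnp.arange(12.0).reshape(4, 3)  # 4 points in R^3

# Compile the vectorised function, or vectorise the compiled one
out_a = jit(vmap(squared_norm))(batch)
out_b = vmap(jit(squared_norm))(batch)

# Map over axis 0 of `batch`; broadcast `scale` to every point
out_c = vmap(scaled_norm, in_axes=(0, None))(batch, 2.0)
```

Both orders of composition give the same values, `[5., 50., 149., 302.]`.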

While Jupyter is not a dependency of the package, it is needed to run the example notebooks. Run the following locally if you haven't already.

pip install --upgrade jupyter notebook

import jax
from jax import random, jit, vmap
import jax.numpy as jnp

import os, time
import numpy as np

from functools import partial

jax.config.update("jax_enable_x64", True)

Manifold definition / point sampling

The routines in this library work for an arbitrary real or complex manifold from which points can be sampled. In this example, we consider complex projective space \(\mathbb{P}^n\), the space of complex lines through the origin in \(\mathbb{C}^{n+1}\).

To sample from \(\mathbb{P}^n\), we use the fact that every such line intersects the unit sphere \(S^{2n+1} \subset \mathbb{C}^{n+1}\) in a circle, the orbit of the \(U(1)\) action \(z \mapsto e^{i\theta} z\). Quotienting by this action gives \(\mathbb{P}^n \simeq S^{2n+1} / U(1)\). Samples from the unit sphere in \(\mathbb{R}^{2n+2}\), with pairs of real coordinates combined into complex numbers, therefore give samples in homogeneous coordinates on projective space. Here we set \(n=5\).
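A minimal sketch of this construction, independent of `cymyc`, follows. Normalised Gaussian vectors are uniformly distributed on the sphere. Pairing consecutive real entries into complex numbers gives homogeneous coordinates; this pairing is an assumed convention and may differ from the one used by `math_utils.to_complex`. Multiplying a sample by a random phase moves it along its \(U(1)\) orbit but leaves the ratios \(z_i/z_0\), and hence the point of \(\mathbb{P}^n\), unchanged. The helper `sample_sphere_complex` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import random

jax.config.update("jax_enable_x64", True)

def sample_sphere_complex(key, n_p, n):
    # Normalised standard Gaussians in R^{2(n+1)} are uniform on S^{2n+1}
    x = random.normal(key, (n_p, 2 * (n + 1)))
    x = x / jnp.linalg.norm(x, axis=1, keepdims=True)
    # Assumed convention: consecutive (Re, Im) pairs form one complex coordinate
    x = x.reshape(n_p, n + 1, 2)
    return jax.lax.complex(x[..., 0], x[..., 1])

key, phase_key = random.split(random.PRNGKey(42))
Z = sample_sphere_complex(key, 1000, 5)  # shape (1000, 6): points of S^11 in C^6

# Act with U(1): multiply each sample by its own phase e^{i theta}
theta = random.uniform(phase_key, (1000, 1)) * 2 * jnp.pi
Z_rot = jnp.exp(1j * theta) * Z

# The ratios z_i / z_0 do not see the phase: same points of P^5
ratios = Z[:, 1:] / Z[:, :1]
ratios_rot = Z_rot[:, 1:] / Z_rot[:, :1]
```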

from cymyc.utils import math_utils

ambient_dim = 5
N = 10000
seed = int(time.time())  # use a fixed value, e.g. 42, for reproducible samples
rng = random.PRNGKey(seed)
rng, _rng = random.split(rng)

def S2np1_uniform(key, n_p, n, dtype=np.float64):
    """
    Sample `n_p` points uniformly on the unit sphere $S^{2n+1}$ in C^{n+1},
    returned as homogeneous coordinates on CP^n.
    """
    # Normalised Gaussian vectors are uniformly distributed on the sphere
    x = random.normal(key, shape=(n_p, 2*(n+1)), dtype=dtype)
    x_norm = x / jnp.linalg.norm(x, axis=1, keepdims=True)
    # Combine pairs of real coordinates into n+1 complex coordinates
    sample = math_utils.to_complex(x_norm.reshape(-1, n+1, 2))

    return jnp.squeeze(sample)
Z = S2np1_uniform(_rng, N, ambient_dim)
Z
Array([[-0.63836841+0.07540906j,  0.05975684+0.35032668j,
        -0.21817568-0.01460647j,  0.34166592+0.30850043j,
        -0.30371408+0.04905124j,  0.12019284-0.30279184j],
       [ 0.3338746 +0.24517945j,  0.00430112+0.03442162j,
         0.49296499+0.5634569j , -0.28264535-0.21595281j,
        -0.18750478+0.1247654j , -0.29667712+0.03804603j],
       [-0.09380982-0.40746492j,  0.00873018+0.16303933j,
        -0.44613948-0.50193767j, -0.11338982-0.34577994j,
        -0.01713955-0.38862749j,  0.17459208+0.18249288j],
       ...,
       [ 0.18246882-0.10256688j,  0.42200205-0.26619606j,
        -0.01144649+0.0401941j ,  0.31980423-0.30983694j,
         0.65976198-0.10411274j, -0.03887321-0.24409503j],
       [-0.13549991+0.3167525j ,  0.21550783+0.18345629j,
         0.4386935 +0.46204158j, -0.39535044+0.47843908j,
         0.04093444-0.05808759j,  0.06500915-0.02813306j],
       [-0.14109266+0.18389524j,  0.06828524-0.05392863j,
        -0.04639297+0.30262549j, -0.60838719-0.33054492j,
        -0.48948304-0.18501462j, -0.03999731-0.30025151j]],      dtype=complex128)

We now use the scaling freedom in projective space to convert homogeneous coordinates \(\left[z_0 : \cdots : z_n\right]\) on \(\mathbb{P}^n\) to inhomogeneous coordinates in the local chart \(U_{\alpha} = \{z_{\alpha} \neq 0\}\). We divide by \(z_{\alpha}\), so that the \(\alpha\)-th coordinate becomes 1, and remove it from the coordinate description,

\[\left[z_0 : \cdots : z_n\right] \mapsto \left(\frac{z_0}{z_{\alpha}}, \ldots, \frac{z_{\alpha-1}}{z_{\alpha}}, \frac{z_{\alpha+1}}{z_{\alpha}}, \ldots, \frac{z_n}{z_{\alpha}}\right) \triangleq \zeta^{(\alpha)}~. \]
Z, _ = math_utils.rescale(Z)
z = vmap(math_utils._inhomogenize)(Z)
z.shape
(10000, 5)
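The library calls above handle the choice of chart; we do not assume here exactly which \(\alpha\) `math_utils.rescale` and `math_utils._inhomogenize` select. A common choice, and the assumption in this sketch, is the chart where \(|z_{\alpha}|\) is largest, so that every inhomogeneous coordinate satisfies \(|\zeta_i| \le 1\). The helper `inhomogenize_max_chart` is hypothetical; it uses fixed-shape indexing so that it stays compatible with jit and vmap.

```python
import jax.numpy as jnp
from jax import vmap

def inhomogenize_max_chart(z):
    """Hypothetical helper: map homogeneous coords z (shape (n+1,)) to the
    chart alpha where |z_alpha| is largest. Returns (zeta, alpha)."""
    alpha = jnp.argmax(jnp.abs(z))
    zeta = z / z[alpha]  # now zeta[alpha] == 1
    # Drop entry alpha with static-shape indices (works under jit/vmap)
    idx = jnp.arange(z.shape[0] - 1)
    idx = jnp.where(idx < alpha, idx, idx + 1)
    return zeta[idx], alpha

# Single point: |z| = (1, 2, 0.5), so alpha = 1 and zeta = (-0.5j, -0.25j)
zeta, alpha = inhomogenize_max_chart(jnp.array([1.0 + 0j, 2j, 0.5]))

# Batched over several points with vmap
Zb = jnp.array([[1.0 + 0j, 2j, 0.5], [0.1, 0.2, -3.0 + 1j]])
zetas, alphas = vmap(inhomogenize_max_chart)(Zb)
```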

Metric definition

There is a natural metric on \(\mathbb{P}^n\), the Fubini-Study metric. Viewing \(\mathbb{P}^n\) as the quotient \(S^{2n+1} / U(1)\), the Fubini-Study metric is the unique metric such that the projection \(\pi: S^{2n+1} \rightarrow \mathbb{P}^n\) is a Riemannian submersion. In inhomogeneous coordinates,

\[ g_{\mu \bar{\nu}} = \frac{1}{\sigma}\left( \delta_{\mu \bar{\nu}} - \frac{\bar{\zeta}_{\mu}\zeta_{\nu}}{\sigma}\right), \quad \sigma = 1 + \sum_{m=1}^n \zeta_m\bar{\zeta}_m~. \]

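The function below returns the FS metric in local coordinates. The metric depends on both \(\zeta\) and \(\bar{\zeta}\), so it is not holomorphic, and forward-mode autodiff in JAX (jax.jacfwd) only accepts complex inputs for functions marked holomorphic. We therefore pass real inputs, using the map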

\[z = (z_1, \ldots, z_n) \in \mathbb{C}^n \mapsto (\Re(z_1), \ldots, \Re(z_n); \Im(z_1), \ldots, \Im(z_n)) \in \mathbb{R}^{2n}~.\]
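A minimal sketch of this map and its inverse follows, written to match the slicing `p[:complex_dim]`, `p[complex_dim:]` used in `fubini_study_metric` below. We assume, without having checked, that `math_utils.to_real` and `math_utils.to_complex` use the same layout; the names `complex_to_real` and `real_to_complex` are hypothetical.

```python
import jax.numpy as jnp

def complex_to_real(z):
    # (z_1, ..., z_n) -> (Re z_1, ..., Re z_n, Im z_1, ..., Im z_n)
    return jnp.concatenate([jnp.real(z), jnp.imag(z)], axis=-1)

def real_to_complex(p):
    # Inverse map: first half real parts, second half imaginary parts
    n = p.shape[-1] // 2
    return p[..., :n] + 1j * p[..., n:]

z = jnp.array([[1 + 2j, 3 - 4j]])  # one point in C^2
p = complex_to_real(z)             # [[1., 3., 2., -4.]]
z_back = real_to_complex(p)        # recovers z
```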
def fubini_study_metric(p):
    """
    Returns FS metric in CP^n evaluated at `p`.
    Parameters
    ----------
        `p`     : 2*complex_dim real inhomogeneous coords at 
                  which metric matrix is evaluated. Shape [i].
    Returns
    ----------
        `g`     : Hermitian metric in CP^n, $g_{ij}$. Shape [i,j].
    """

    # Inhomogeneous coords
    complex_dim = p.shape[-1]//2
    zeta = jax.lax.complex(p[:complex_dim],
                           p[complex_dim:])
    zeta_bar = jnp.conjugate(zeta)
    zeta_sq = 1. + jnp.sum(zeta * zeta_bar)

    # zeta_outer[i, j] = conj(zeta_i) * zeta_j
    zeta_outer = jnp.einsum('...i,...j->...ij', zeta_bar, zeta)

    # Match the dtype of zeta (complex128 with x64 enabled)
    delta_mn = jnp.eye(complex_dim, dtype=zeta.dtype)

    g_FS = jnp.divide(delta_mn * zeta_sq - zeta_outer, jnp.square(zeta_sq))

    return g_FS
p = math_utils.to_real(z)
g_FS = vmap(fubini_study_metric)(p)
g_FS.shape
(10000, 5, 5)

We can benchmark execution times with and without jit-compilation - note the exact speedup will depend on the hardware available.

%%timeit
_ = vmap(fubini_study_metric)(p).block_until_ready()
3.06 ms ± 269 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%%timeit
_ = vmap(jit(fubini_study_metric))(p).block_until_ready()
877 μs ± 4.79 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
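`%%timeit` is an IPython magic; outside a notebook, a rough equivalent can be sketched with `time.perf_counter`. The sketch below separates the first call, which pays for tracing and compilation, from the steady-state cost. It also times `jit(vmap(...))`, the other order of composition, which compiles the batched function as a whole. The condensed metric `fs_metric` and the helper `time_calls` are hypothetical copies for a self-contained script; we make no claim about which variant is faster on a given machine.

```python
import time

import jax
import jax.numpy as jnp
from jax import jit, vmap

jax.config.update("jax_enable_x64", True)

def fs_metric(p):
    # Condensed copy of fubini_study_metric above; p = (Re zeta; Im zeta)
    n = p.shape[-1] // 2
    zeta = jax.lax.complex(p[:n], p[n:])
    sigma = 1.0 + jnp.sum(p**2)  # 1 + sum_m |zeta_m|^2
    return (jnp.eye(n) * sigma - jnp.outer(jnp.conj(zeta), zeta)) / sigma**2

def time_calls(fn, x, repeats=20):
    # First call includes tracing and, for jitted functions, XLA compilation
    t0 = time.perf_counter()
    fn(x).block_until_ready()
    first = time.perf_counter() - t0
    # Later calls reuse the cached compiled executable
    t0 = time.perf_counter()
    for _ in range(repeats):
        fn(x).block_until_ready()
    steady = (time.perf_counter() - t0) / repeats
    return first, steady

p = jax.random.normal(jax.random.PRNGKey(0), (10000, 10))
timings = {
    "vmap": time_calls(vmap(fs_metric), p),
    "jit(vmap)": time_calls(jit(vmap(fs_metric)), p),
}
```

JAX dispatches work asynchronously, so `block_until_ready()` must be called before the clock is stopped; otherwise only the dispatch time is measured.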

The Kähler potential

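\(\mathbb{P}^n\) is a Kähler manifold. Among the many special properties this implies, the metric is locally determined by a single real scalar function, the Kähler potential. This function is only defined locally, \(\mathcal{K} \in C^{\infty}(U_{\alpha})\) on a coordinate chart \(U_{\alpha}\), since a compact Kähler manifold admits no global Kähler potential. For the Fubini-Study metric in inhomogeneous coordinates,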

\[\begin{align*} g_{\mu \bar{\nu }} &= \partial_{\mu}\overline{\partial}_{\bar{\nu}} \mathcal{K}~, \\ \mathcal{K} &= \log \left( 1+ \sum_{m=1}^n \left\vert \zeta_m \right\vert^2\right)~. \end{align*}\]

This is particularly important in the context of approximating metrics, as it allows one to reduce the problem to approximation of a single scalar function.

def fubini_study_potential(p):
    """
    Returns Kahler potential associated with the FS metric
    in CP^n evaluated at `p`.
    Parameters
    ----------
        `p`        : 2*complex_dim real inhomogeneous coords at 
                     which potential is evaluated. Shape [i].
    Returns
    ----------
        `phi`      : Kahler potential, real scalar. Shape [].  
    """
    zeta_sq = jnp.sum(p**2)
    return jnp.log(1. + zeta_sq)
from cymyc import curvature
_g_FS = vmap(curvature.del_z_bar_del_z, in_axes=(0,None))(p, fubini_study_potential)
_g_FS.shape
(10000, 5, 5)
jnp.allclose(g_FS, _g_FS)
Array(True, dtype=bool)
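The comparison above relies on `curvature.del_z_bar_del_z`. To make the underlying Wirtinger calculus explicit, the sketch below builds \(\partial_{\mu}\overline{\partial}_{\bar{\nu}}\mathcal{K}\) from the ordinary real Hessian \(H\) of \(\mathcal{K}\) in the coordinates \((x, y) = (\Re\zeta, \Im\zeta)\), using \(\partial_{\mu} = \tfrac{1}{2}(\partial_{x_\mu} - i\partial_{y_\mu})\) and \(\overline{\partial}_{\bar{\nu}} = \tfrac{1}{2}(\partial_{x_\nu} + i\partial_{y_\nu})\). Expanding the product gives \(\partial_{\mu}\overline{\partial}_{\bar{\nu}}\mathcal{K} = \tfrac{1}{4}\left(H_{xx} + H_{yy} + i(H_{xy} - H_{yx})\right)_{\mu\nu}\). The helper `ddbar` is hypothetical, and we make no claim that the library computes the derivative this way.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import vmap

jax.config.update("jax_enable_x64", True)

def fubini_study_potential(p):
    # K = log(1 + sum_m |zeta_m|^2), with p = (Re zeta; Im zeta)
    return jnp.log(1.0 + jnp.sum(p**2))

def fubini_study_metric(p):
    # Condensed copy of the FS metric above
    n = p.shape[-1] // 2
    zeta = jax.lax.complex(p[:n], p[n:])
    sigma = 1.0 + jnp.sum(p**2)
    return (jnp.eye(n) * sigma - jnp.outer(jnp.conj(zeta), zeta)) / sigma**2

def ddbar(f, p):
    # d_mu dbar_nu f for a real scalar f(x, y), built from its real Hessian
    n = p.shape[-1] // 2
    H = jax.hessian(f)(p)  # shape (2n, 2n), real
    H_xx, H_xy = H[:n, :n], H[:n, n:]
    H_yx, H_yy = H[n:, :n], H[n:, n:]
    return 0.25 * (H_xx + H_yy + 1j * (H_xy - H_yx))

p = jax.random.normal(jax.random.PRNGKey(1), (100, 10))  # 100 points, n = 5
g_from_potential = vmap(partial(ddbar, fubini_study_potential))(p)
g_direct = vmap(fubini_study_metric)(p)
```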

Riemann tensor

Curvature quantities associated with a metric involve derivatives of the metric; if the metric tensor is available as a function of the coordinates, these can be computed numerically using autodiff. The fundamental curvature quantity is the Riemann curvature, the endomorphism-valued two-form that captures local curvature effects on a manifold \(X\), \(\textsf{Riem} \in \Omega^2(X; \textsf{End}(T_X))\).

Schematically, the curvature tensor is obtained by taking two derivatives of the metric tensor with respect to the input coordinates. Below, \(\Gamma\) denotes the Christoffel symbols of the Levi-Civita connection in local coordinates,

\[\textsf{Riem} \sim \partial \Gamma + \Gamma \cdot \Gamma, \quad \Gamma \sim g^{-1} \partial g~.\]
riem = vmap(curvature.riemann_tensor_kahler, in_axes=(0,None))(p, jax.tree_util.Partial(fubini_study_metric))

This involves two derivatives of a potentially expensive function, but is still fast for \(10^4\) points, as benchmarking shows. Here the library function is already jit-ed at definition, and nested jits are equivalent to a single jit. The metric is wrapped in jax.tree_util.Partial so that the function itself can be passed as an argument to jit- and vmap-transformed code.

%%timeit
_ = vmap(curvature.riemann_tensor_kahler, in_axes=(0,None))(p, jax.tree_util.Partial(fubini_study_metric)).block_until_ready()
286 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
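`curvature.riemann_tensor_kahler` does this for us. To see what the schematic formula amounts to on a Kähler manifold, here is a sketch using the standard Kähler simplifications \(\Gamma^{\kappa}_{\mu\lambda} = g^{\kappa\bar{\sigma}}\partial_{\mu} g_{\lambda\bar{\sigma}}\) and \(\textsf{Riem}^{\kappa}_{\;\lambda\mu\bar{\nu}} = -\overline{\partial}_{\bar{\nu}}\Gamma^{\kappa}_{\mu\lambda}\); the \(\Gamma \cdot \Gamma\) term drops out of these mixed components. Sign and index-ordering conventions vary between references. We do not claim this matches the library's internal layout, and all helper names are hypothetical. As checks, tracing \(\kappa = \lambda\) reproduces \(\textsf{Ric} = (n+1) g_{FS}\), and the lowered tensor has the symmetries discussed in the next section.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import jacfwd, vmap

jax.config.update("jax_enable_x64", True)

def fubini_study_metric(p):
    # Condensed copy of the FS metric above; p = (Re zeta; Im zeta)
    n = p.shape[-1] // 2
    zeta = jax.lax.complex(p[:n], p[n:])
    sigma = 1.0 + jnp.sum(p**2)
    return (jnp.eye(n) * sigma - jnp.outer(jnp.conj(zeta), zeta)) / sigma**2

def holo(d, n):
    # Trailing axis holds d/dx_1..d/dx_n, d/dy_1..d/dy_n; d_mu = (d_x - i d_y) / 2
    return 0.5 * (d[..., :n] - 1j * d[..., n:])

def antiholo(d, n):
    # dbar_nu = (d_x + i d_y) / 2
    return 0.5 * (d[..., :n] + 1j * d[..., n:])

def christoffel(metric, p):
    # Gamma^k_{m l} = g^{k sbar} d_m g_{l sbar}, stored as Gamma[k, m, l]
    n = p.shape[-1] // 2
    g_inv = jnp.linalg.inv(metric(p))  # g_inv[s, k] = g^{sbar k}
    dg = holo(jacfwd(metric)(p), n)    # dg[l, s, m] = d_m g_{l sbar}
    return jnp.einsum('sk,lsm->kml', g_inv, dg)

def riemann(metric, p):
    # Riem^k_{l m nubar} = -dbar_nu Gamma^k_{m l}, stored as R[k, l, m, nu]
    n = p.shape[-1] // 2
    dGamma = antiholo(jacfwd(partial(christoffel, metric))(p), n)  # [k, m, l, nu]
    return -jnp.einsum('kmln->klmn', dGamma)

p = jax.random.normal(jax.random.PRNGKey(2), (50, 10))  # 50 points, n = 5
g = vmap(fubini_study_metric)(p)
R = vmap(partial(riemann, fubini_study_metric))(p)       # shape (50, 5, 5, 5, 5)

# Ricci: trace over the endomorphism indices k = l
ricci = jnp.trace(R, axis1=1, axis2=2)                   # shape (50, 5, 5)

# Lower k with the metric: R_low[l, jbar, m, nubar] = g_{k jbar} Riem^k_{l m nubar}
R_low = jnp.einsum('...kj,...klmn->...ljmn', g, R)
```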

First Bianchi identity

We form the Riemann tensor with all indices lowered using the musical isomorphism defined by the metric. The resulting tensor satisfies the following symmetries, as a consequence of the first Bianchi identity,

\[ \textsf{Riem}_{a\overline{b}c\overline{d}} = \textsf{Riem}_{a \overline{d} c \overline{b}} = \textsf{Riem}_{c \overline{b} a \overline{d}} = \textsf{Riem}_{c \overline{d} a \overline{b}}~.\]
riem_lower = jnp.einsum('...ibcd, ...ia->...bacd', riem, g_FS)
jnp.allclose(riem_lower, jnp.einsum('...abcd->...adcb', riem_lower))  # first equality
Array(True, dtype=bool)
jnp.allclose(riem_lower, jnp.einsum('...abcd->...cbad', riem_lower))  # second equality
Array(True, dtype=bool)
jnp.allclose(riem_lower, jnp.einsum('...abcd->...cdab', riem_lower))  # third equality
Array(True, dtype=bool)

Ricci curvature

Complex projective space is an Einstein manifold, meaning that the Fubini-Study metric on \(\mathbb{P}^n\) is proportional to the Ricci curvature. The Ricci curvature is another important measure of curvature derived from \(\textsf{Riem}\), which roughly measures the degree of volume distortion relative to Euclidean space as one travels along geodesics emanating from a given point.

\[\textsf{Ric} = \Lambda g~.\]

For \(\mathbb{P}^n\) with the Fubini-Study metric, the Einstein constant is \(\Lambda = n+1\), i.e. \(\textsf{Ric} = (n+1) g_{FS}\).

The Ricci curvature is given, in local coordinates, as the trace of the endomorphism part of the Riemann curvature tensor,

\[ \textsf{Ric}_{\mu \bar{\nu}} \triangleq \textsf{Riem}^{\kappa}_{\; \kappa \mu \bar{\nu}} = \textsf{Riem}^{\kappa}_{\; \mu \kappa \bar{\nu}}~.\]
ricci = vmap(curvature.ricci_tensor_kahler, in_axes=(0,None))(p, jax.tree_util.Partial(fubini_study_metric))
jnp.allclose(ricci, (ambient_dim + 1) * g_FS)
Array(True, dtype=bool)

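Since \(g^{\bar{\nu}\mu} g_{\mu\bar{\nu}} = n\), this also means that the Ricci scalar, the trace of the Ricci curvature with respect to the metric, should be, on \(\mathbb{P}^n\) (the tiny imaginary parts in the numerical output below are floating-point round-off):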

\[ \textsf{R} = n(n+1)~.\]
jnp.einsum('...ba, ...ab', jnp.linalg.inv(g_FS), ricci)
Array([30.-8.33123781e-16j, 30.+1.21746809e-16j, 30.+1.72612599e-15j, ...,
       30.-7.98853592e-17j, 30.+1.10371070e-16j, 30.+9.64313597e-17j],      dtype=complex128)
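For a Kähler metric, the trace of the Riemann tensor also has a closed form, \(\textsf{Ric}_{\mu\bar{\nu}} = -\partial_{\mu}\overline{\partial}_{\bar{\nu}}\log\det g\), which needs only the metric and one real Hessian. The sketch below uses this identity as an independent cross-check of the Einstein condition and of \(\textsf{R} = n(n+1)\). It is independent of `cymyc`; the helpers `ddbar`, `log_det_metric` and `ricci_from_log_det` are hypothetical, and `ddbar` uses the Wirtinger combination of real second derivatives \(\tfrac{1}{4}\left(H_{xx} + H_{yy} + i(H_{xy} - H_{yx})\right)\).

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import vmap

jax.config.update("jax_enable_x64", True)

def fubini_study_metric(p):
    # Condensed copy of the FS metric above; p = (Re zeta; Im zeta)
    n = p.shape[-1] // 2
    zeta = jax.lax.complex(p[:n], p[n:])
    sigma = 1.0 + jnp.sum(p**2)
    return (jnp.eye(n) * sigma - jnp.outer(jnp.conj(zeta), zeta)) / sigma**2

def ddbar(f, p):
    # d_mu dbar_nu f for a real scalar f(x, y), built from its real Hessian
    n = p.shape[-1] // 2
    H = jax.hessian(f)(p)
    return 0.25 * (H[:n, :n] + H[n:, n:] + 1j * (H[:n, n:] - H[n:, :n]))

def log_det_metric(metric, p):
    # g is Hermitian positive definite, so log|det g| = log det g (a real scalar)
    _, logabsdet = jnp.linalg.slogdet(metric(p))
    return logabsdet

def ricci_from_log_det(metric, p):
    # Kähler identity: Ric_{mu nubar} = -d_mu dbar_nu log det g
    return -ddbar(partial(log_det_metric, metric), p)

n = 5
p = jax.random.normal(jax.random.PRNGKey(3), (100, 2 * n))
g = vmap(fubini_study_metric)(p)
ricci = vmap(partial(ricci_from_log_det, fubini_study_metric))(p)

# Ricci scalar R = g^{nubar mu} Ric_{mu nubar}, the same contraction as above
R_scalar = jnp.einsum('...ba,...ab->...', jnp.linalg.inv(g), ricci)
```

For the Fubini-Study metric \(\det g = \sigma^{-(n+1)}\), so \(-\partial\overline{\partial}\log\det g = (n+1)\,\partial\overline{\partial}\log\sigma = (n+1)\,g\), which is what the sketch recovers numerically.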