Curvature

In this introductory example we compute various curvature quantities from the Riemannian metric on a manifold and compare them against analytically known results. Everything here should be accessible with a basic knowledge of scientific computing and differential geometry. We use two Jax-specific transformations, explained briefly below; for more detail, please see the official guides.

  • jax.jit: Short for Just-in-Time compilation, this traces a Jax Python function and compiles it, via the XLA compiler, into an optimised sequence of primitive operations for the available backend (CPU, GPU or TPU). The output of jit is another function, usually one that executes significantly faster than the Python equivalent. The price to be paid is that the program logic of a jit-compatible function is constrained by the compiler (for example, array shapes must be static and Python control flow cannot depend on traced array values), so you don't want (or need) to jit everything.
  • jax.vmap: Short for Vectorising Map, this transforms a Jax Python function written for a single input (e.g. one point on the manifold) into one that is automatically vectorised across the specified array axes. Again, the program logic of a vmap-compatible function is restricted.

Jax transformations are composable: you can jit a vmap-ed function and vice versa. And that's pretty much all you need to know to understand this example!
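Here is a minimal illustration of composing the two transformations. It uses a toy function that is not part of the library. A per-point function is vectorised with vmap, and the batched version is compiled with jit. Both orders of composition give the same values.

```python
import jax
import jax.numpy as jnp

def squared_norm(x):
    # Written for a single point: x has shape [d]
    return jnp.sum(x ** 2)

# Vectorise over the leading (batch) axis, then compile the batched function
batched = jax.jit(jax.vmap(squared_norm))
# The other order of composition is also valid
batched_alt = jax.vmap(jax.jit(squared_norm))

points = jnp.arange(12.0).reshape(4, 3)  # 4 points in R^3
print(batched(points))      # values 5, 50, 149, 302
print(batched_alt(points))  # same values
```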

While Jupyter is not a dependency of the package, the example notebooks require it; run the following locally if you haven't already installed it.

pip install --upgrade jupyter notebook

import jax
from jax import random, jit, vmap
import jax.numpy as jnp

import os, time
import numpy as np

from functools import partial

jax.config.update("jax_enable_x64", True)

Manifold definition / point sampling

The routines in this library work for an arbitrary real or complex manifold from which points can be sampled. In this example, we consider complex projective space \(\mathbb{P}^n\), the space of complex lines through the origin in \(\mathbb{C}^{n+1}\).

To sample from \(\mathbb{P}^n\), we use the fact that every complex line through the origin intersects the unit sphere \(S^{2n+1} \subset \mathbb{C}^{n+1}\) in a circle, the orbit of the \(U(1)\) action \(z \mapsto e^{i\theta} z\). Quotienting by this action gives \(\mathbb{P}^n \simeq S^{2n+1} / U(1)\), so samples from the unit sphere, appropriately complexified, give samples in homogeneous coordinates on projective space. Here we set \(n=10\) (the variable ambient_dim below).
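The sampling step can be sketched in pure Jax as follows. This is an illustrative reimplementation, not the library's own routine. The helper name sample_cpn is hypothetical, and we assume the real coordinates are paired as (real, imaginary) parts.

The sketch also checks the quotient picture. Multiplying a sample by a phase \(e^{i\theta}\) keeps it on the sphere and leaves the ratios \(z_j/z_0\), and hence the point of \(\mathbb{P}^n\), unchanged.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def sample_cpn(key, n_p, n):
    """Hypothetical helper: n_p uniform samples on S^{2n+1} as vectors in C^{n+1}."""
    # Normalised standard Gaussian vectors are uniformly distributed on the sphere
    x = jax.random.normal(key, shape=(n_p, 2 * (n + 1)))
    x = x / jnp.linalg.norm(x, axis=1, keepdims=True)
    x = x.reshape(n_p, n + 1, 2)
    # Assumed convention: the last axis holds (real, imaginary) parts
    return x[..., 0] + 1j * x[..., 1]

Z = sample_cpn(jax.random.PRNGKey(0), n_p=4, n=3)
print(Z.shape)  # (4, 4)

# Points lie on the unit sphere in C^{n+1}
print(jnp.allclose(jnp.linalg.norm(Z, axis=1), 1.0))  # True

# A U(1) phase moves along the fibre but not in P^n: ratios z_j / z_0 are unchanged
Z_rot = jnp.exp(0.7j) * Z
print(jnp.allclose(Z_rot[:, 1:] / Z_rot[:, :1], Z[:, 1:] / Z[:, :1]))  # True
```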

from cymyc.utils import math_utils

ambient_dim = 10
N = 10
seed = int(time.time()) # 42
rng = random.PRNGKey(seed)
rng, _rng = random.split(rng)

def S2np1_uniform(key, n_p, n, dtype=np.float64):
    """
    Sample `n_p` points uniformly on the unit sphere $S^{2n+1}$ and return
    them as complex homogeneous coordinates on CP^n.
    """
    # Normalised Gaussian vectors are uniformly distributed on the sphere
    x = random.normal(key, shape=(n_p, 2*(n+1)), dtype=dtype)
    x_norm = x / jnp.linalg.norm(x, axis=1, keepdims=True)
    # Reshape to [n_p, n+1, 2] and combine into n+1 complex coordinates
    sample = math_utils.to_complex(x_norm.reshape(-1, n+1, 2))

    return jnp.squeeze(sample)

Z = S2np1_uniform(_rng, N, ambient_dim)
Z
Array([[ 0.02736979-0.22389161j,  0.25014222-0.34731673j,
        -0.00753072-0.29925316j,  0.25562023+0.29378541j,
        -0.05679859-0.20567943j,  0.04690634+0.12982876j,
         0.15532516+0.14336839j,  0.20046198-0.38581777j,
        -0.26256972+0.28638481j, -0.12090856-0.13025232j,
         0.00636905+0.20921725j],
       [ 0.00833806-0.12314272j,  0.33281516+0.1516645j ,
        -0.1483373 -0.16471397j, -0.1052314 +0.2249917j ,
        -0.08208166-0.10561346j,  0.49581376-0.1674636j ,
        -0.00324378-0.21552658j,  0.24867794-0.19161808j,
        -0.14643599+0.17849217j,  0.41380261-0.19468098j,
        -0.17593772-0.09995519j],
       [ 0.27014354+0.43684002j, -0.20377776-0.11443033j,
        -0.21397595-0.00243355j,  0.1063651 +0.17664001j,
        -0.41157039-0.09987887j,  0.07792631+0.01783108j,
        -0.1407708 -0.14530487j, -0.00628455-0.20375251j,
        -0.26621185+0.2402969j , -0.12324527-0.3368604j ,
        -0.08962662+0.24437101j],
       [ 0.298288  +0.0090613j , -0.09293679+0.12098218j,
         0.15873836+0.30437709j,  0.09010836+0.24775444j,
         0.12491493+0.0599167j ,  0.06105556+0.02011806j,
        -0.06107647-0.08125648j,  0.39116699+0.09845587j,
        -0.06334902+0.10514544j, -0.37315632+0.03532256j,
         0.1744366 -0.5638871j ],
       [-0.03686343-0.11681997j, -0.04676194-0.14645691j,
         0.16080594-0.16616212j,  0.16180508+0.22286588j,
        -0.20503923+0.14494859j,  0.2959065 -0.2568473j ,
         0.28694812-0.21742779j, -0.29545629+0.15549756j,
        -0.43242921-0.02473469j,  0.20736562+0.35661549j,
        -0.10635665+0.07263177j],
       [ 0.37908512+0.22849334j,  0.012651  -0.37016357j,
         0.22781669-0.11919621j,  0.14999072+0.33529792j,
         0.17039864+0.07631351j,  0.31776397-0.03556206j,
         0.22090967+0.03842842j,  0.13607875+0.32494174j,
        -0.24973331+0.02489945j, -0.10798666-0.08427205j,
         0.21141219-0.16717926j],
       [ 0.1293235 +0.20853088j,  0.11784573-0.14996482j,
        -0.26778802+0.11306106j,  0.0662057 +0.04690227j,
         0.37957281-0.12004818j,  0.27164364+0.26007059j,
        -0.22151153-0.03442996j,  0.10460384+0.45146498j,
        -0.15981234-0.348539j  ,  0.08582687+0.19522458j,
        -0.20275629-0.11746974j],
       [ 0.2527872 +0.06671784j, -0.10444607+0.02656615j,
         0.08960571-0.07636178j,  0.08993089-0.34046328j,
        -0.20523979+0.51587811j, -0.33898292-0.14847849j,
         0.17718462+0.22234663j,  0.0333801 -0.30936953j,
         0.20401253-0.20578521j,  0.12323492-0.24393537j,
        -0.0247256 +0.00575643j],
       [-0.00236344-0.47434857j,  0.35553164-0.21107713j,
         0.10827872-0.14736322j,  0.0215133 +0.04707496j,
        -0.24504572-0.09207816j, -0.3192196 +0.02622722j,
         0.02600101+0.03722147j, -0.21314049-0.04537127j,
        -0.1855012 -0.37369281j, -0.31240302-0.2581296j ,
        -0.08234807-0.04673321j],
       [ 0.12391426-0.15099357j, -0.10282917-0.07848328j,
         0.26018045-0.39339781j, -0.12754509+0.16708262j,
         0.12648189-0.10537365j,  0.02287195+0.006204j  ,
         0.0445617 -0.0977747j , -0.37106113+0.25683421j,
        -0.13924258+0.10475447j,  0.45967485+0.29416271j,
         0.27354486-0.18053403j]], dtype=complex128)

We now use the scaling freedom in projective space to convert homogeneous coordinates \(\left[z_0 : \cdots : z_n\right]\) on \(\mathbb{P}^n\) to inhomogeneous coordinates in a local coordinate chart where \(z_{\alpha}\) is nonzero. We do this by setting \(z_{\alpha} = 1\) and removing it from the coordinate description,

\[\left[z_0 : \cdots : z_n\right] \mapsto \left(\frac{z_0}{z_{\alpha}}, \ldots, \frac{z_{\alpha-1}}{z_{\alpha}}, \frac{z_{\alpha+1}}{z_{\alpha}}, \ldots, \frac{z_n}{z_{\alpha}}\right) \triangleq \zeta^{(\alpha)}~. \]
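Here is a minimal sketch of this map for a single point. It assumes the chart is chosen as the coordinate of largest modulus, a common choice for numerical stability; math_utils.rescale and math_utils._inhomogenize may differ in detail. The helper name inhomogenize is hypothetical.

```python
import jax
import jax.numpy as jnp

def inhomogenize(Z):
    """Hypothetical helper: homogeneous coords Z (shape [n+1]) -> zeta^(alpha) (shape [n])."""
    n_plus_1 = Z.shape[-1]
    # Choose the chart alpha where |z_alpha| is largest, so it is safely nonzero
    alpha = jnp.argmax(jnp.abs(Z))
    scaled = Z / Z[alpha]  # now scaled[alpha] == 1
    # Drop the alpha-th entry with a fixed-size gather (jit/vmap friendly)
    idx = jnp.arange(n_plus_1 - 1)
    idx = idx + (idx >= alpha)
    return scaled[idx], alpha

Z = jnp.array([1.0 + 1.0j, 4.0 + 0.0j, 0.0 - 2.0j])
zeta, alpha = inhomogenize(Z)
print(alpha)  # 1
print(zeta)   # values (0.25+0.25j, -0.5j)

# Batched over points with vmap; each point may use a different chart
Zs = jnp.stack([Z, jnp.roll(Z, 1)])
zetas, alphas = jax.vmap(inhomogenize)(Zs)
print(alphas)  # values 1, 2
```

The removed index is handled with an index array of fixed length rather than a Python-level deletion. This keeps the output shape static, which is what allows the function to be vmap-ed and jit-ed.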
Z, _ = math_utils.rescale(Z)
z = vmap(math_utils._inhomogenize)(Z)
z.shape
(10, 10)

Metric definition

There is a natural metric on \(\mathbb{P}^n\), the Fubini-Study metric. Viewing \(\mathbb{P}^n\) as the quotient \(S^{2n+1} / U(1)\), the Fubini-Study metric is the unique metric such that the projection \(\pi: S^{2n+1} \rightarrow \mathbb{P}^n\) is a Riemannian submersion. In inhomogeneous coordinates,

\[ g_{\mu \bar{\nu}} = \frac{1}{\sigma}\left( \delta_{\mu \overline{\nu}} - \frac{\bar{\zeta}_{\mu}\zeta_{\nu}}{\sigma}\right), \quad \sigma = 1 + \sum_{m=1}^n \zeta_m\bar{\zeta}_m~. \]

The function below returns the FS metric in local coordinates. Note it requires a real input for autodiff to play nice, so we use the map

\[z = (z_1, \ldots, z_n) \in \mathbb{C}^n \mapsto (\Re(z_1), \ldots, \Re(z_n); \Im(z_1), \ldots, \Im(z_n)) \in \mathbb{R}^{2n}~.\]
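For reference, the map above and its inverse can each be sketched in a couple of lines. The helper names are hypothetical stand-ins for math_utils.to_real and math_utils.to_complex. The (real parts; imaginary parts) ordering matches the split used inside fubini_study_metric below.

```python
import jax.numpy as jnp

def complex_to_real(z):
    """Hypothetical helper: (z_1..z_n) in C^n -> (Re z_1..Re z_n; Im z_1..Im z_n) in R^{2n}."""
    return jnp.concatenate([jnp.real(z), jnp.imag(z)], axis=-1)

def real_to_complex(p):
    """Hypothetical inverse: split the real vector in half and recombine."""
    n = p.shape[-1] // 2
    return p[..., :n] + 1j * p[..., n:]

z = jnp.array([1.0 + 2.0j, -3.0 + 0.5j])
p = complex_to_real(z)
print(p)                                    # values 1, -3, 2, 0.5
print(jnp.allclose(real_to_complex(p), z))  # True
```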
def fubini_study_metric(p):
    """
    Returns FS metric in CP^n evaluated at `p`.
    Parameters
    ----------
        `p`     : 2*complex_dim real inhomogeneous coords at 
                  which metric matrix is evaluated. Shape [i].
    Returns
    ----------
        `g`     : Hermitian metric in CP^n, $g_{ij}$. Shape [i,j].
    """

    # Inhomogeneous coords
    complex_dim = p.shape[-1]//2
    zeta = jax.lax.complex(p[:complex_dim],
                           p[complex_dim:])
    zeta_bar = jnp.conjugate(zeta)
    zeta_sq = 1. + jnp.sum(zeta * zeta_bar)

    zeta_outer = jnp.einsum('...i,...j->...ij', zeta_bar, zeta)

    delta_mn = jnp.eye(complex_dim, dtype=zeta.dtype)  # match precision of zeta

    g_FS = jnp.divide(delta_mn * zeta_sq - zeta_outer, jnp.square(zeta_sq))

    return g_FS
p = math_utils.to_real(z)
g_FS = vmap(fubini_study_metric)(p)
g_FS.shape
(10, 10, 10)

We can benchmark execution times with and without jit-compilation - note the exact speedup will depend on the hardware available.

%%timeit
_ = vmap(fubini_study_metric)(p).block_until_ready()
6.09 ms ± 271 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%%timeit
_ = vmap(jit(fubini_study_metric))(p).block_until_ready()
629 μs ± 11.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
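Outside a notebook, a similar comparison can be sketched with the standard-library timer; the toy workload below is illustrative only. Two details matter. First, the first call to a jitted function includes tracing and compilation, so a warm-up call excludes it. Second, block_until_ready is needed because Jax dispatches work asynchronously.

```python
import time
import jax
import jax.numpy as jnp

def per_point(x):
    # Toy per-point workload: outer product plus identity, then a matrix trace
    m = jnp.outer(x, x) + jnp.eye(x.shape[0])
    return jnp.trace(m @ m)

x = jax.random.normal(jax.random.PRNGKey(0), (1000, 16))
eager = jax.vmap(per_point)
compiled = jax.jit(jax.vmap(per_point))

def mean_time(f, x, repeats=20):
    f(x).block_until_ready()  # warm-up: excludes one-off compilation cost
    start = time.perf_counter()
    for _ in range(repeats):
        f(x).block_until_ready()  # wait for asynchronous dispatch to finish
    return (time.perf_counter() - start) / repeats

t_eager, t_jit = mean_time(eager, x), mean_time(compiled, x)
print(f"eager: {t_eager:.2e} s, jit: {t_jit:.2e} s")
```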

The Kähler potential

\(\mathbb{P}^n\) is a Kähler manifold - this imbues it with many special properties, one of them being that the metric is locally determined by a single real scalar function, the Kähler potential, \(\mathcal{K} \in C^{\infty}(\mathbb{P}^n)\).

\[\begin{align*} g_{\mu \bar{\nu }} &= \partial_{\mu}\overline{\partial}_{\bar{\nu}} \mathcal{K}~, \\ \mathcal{K} &= \log \left( 1+ \sum_{m=1}^n \left\vert \zeta_m \right\vert^2\right)~. \end{align*}\]

This is particularly important in the context of approximating metrics, as it allows one to reduce the problem to approximation of a single scalar function.
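To see how the metric follows from the potential with autodiff, here is a minimal sketch. It forms \(\partial_{\mu}\overline{\partial}_{\bar{\nu}}\mathcal{K}\) from the real Hessian using the Wirtinger derivatives \(\partial_{\mu} = \tfrac{1}{2}(\partial_{x_\mu} - i\partial_{y_\mu})\) and \(\overline{\partial}_{\bar{\nu}} = \tfrac{1}{2}(\partial_{x_\nu} + i\partial_{y_\nu})\), where \(\zeta = x + iy\). This is not necessarily how curvature.del_z_bar_del_z is implemented, and the function names are our own.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def kahler_potential(p):
    # K = log(1 + sum |zeta|^2), with p = (Re zeta; Im zeta)
    return jnp.log(1.0 + jnp.sum(p ** 2))

def metric_from_potential(potential, p):
    """g_{mu nubar} = d_mu dbar_nu K, assembled from the real 2n x 2n Hessian."""
    n = p.shape[-1] // 2
    H = jax.hessian(potential)(p)
    Hxx, Hxy = H[:n, :n], H[:n, n:]
    Hyx, Hyy = H[n:, :n], H[n:, n:]
    # (1/4)(d_x - i d_y)_mu (d_x + i d_y)_nu
    return 0.25 * (Hxx + Hyy + 1j * (Hxy - Hyx))

def fubini_study_closed_form(p):
    n = p.shape[-1] // 2
    zeta = p[:n] + 1j * p[n:]
    sigma = 1.0 + jnp.sum(jnp.abs(zeta) ** 2)
    return jnp.eye(n) / sigma - jnp.outer(jnp.conj(zeta), zeta) / sigma ** 2

p = 0.5 * jax.random.normal(jax.random.PRNGKey(1), (6,))  # a point in C^3
g_auto = metric_from_potential(kahler_potential, p)
print(jnp.allclose(g_auto, fubini_study_closed_form(p)))  # True
```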

def fubini_study_potential(p):
    """
    Returns Kahler potential associated with the FS metric
    in CP^n evaluated at `p`.
    Parameters
    ----------
        `p`        : 2*complex_dim real inhomogeneous coords at 
                     which potential is evaluated. Shape [i].
    Returns
    ----------
        `phi`      : Kahler potential, real scalar. Shape [].  
    """
    zeta_sq = jnp.sum(p**2)
    return jnp.log(1. + zeta_sq)
from cymyc import curvature
_g_FS = vmap(curvature.del_z_bar_del_z, in_axes=(0,None))(p, fubini_study_potential)
_g_FS.shape
(10, 10, 10)
jnp.allclose(g_FS, _g_FS)
Array(True, dtype=bool)

Riemann tensor

Measures of curvature corresponding to a given metric tensor involve derivatives of the metric - if a function corresponding to the metric tensor is known, these may be easily computed numerically using autodiff. The most important curvature quantity is the Riemann curvature - the endomorphism-valued two-form that informs us about local curvature effects, \(\textsf{Riem} \in \Omega^2(X; \textsf{End}(T_X))\).

Schematically, the curvature tensor is given by taking two derivatives of the metric tensor w.r.t. the input coordinates. \(\Gamma\) below refers to the Levi-Civita connection in local coordinates,

\[\textsf{Riem} \sim \partial \Gamma + \Gamma \cdot \Gamma, \quad \Gamma \sim g^{-1} \partial g~.\]
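For a Kähler metric, the Levi-Civita connection has only the unmixed components \(\Gamma^{\kappa}_{\mu\lambda} = g^{\kappa\bar{\nu}}\partial_{\mu} g_{\lambda\bar{\nu}}\). These can be sketched with forward-mode autodiff applied to a real-input metric function. The sketch uses our own index conventions and is not a description of curvature.riemann_tensor_kahler. The symmetry check \(\Gamma^{\kappa}_{\mu\lambda} = \Gamma^{\kappa}_{\lambda\mu}\) follows from the Kähler condition \(\partial_{\mu} g_{\lambda\bar{\nu}} = \partial_{\lambda} g_{\mu\bar{\nu}}\).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def fs_metric(p):
    # Fubini-Study metric g[lam, nu] = g_{lam nubar}, with p = (Re zeta; Im zeta)
    n = p.shape[-1] // 2
    zeta = p[:n] + 1j * p[n:]
    sigma = 1.0 + jnp.sum(jnp.abs(zeta) ** 2)
    return jnp.eye(n) / sigma - jnp.outer(jnp.conj(zeta), zeta) / sigma ** 2

def christoffel_kahler(metric_fn, p):
    """Gamma[k, m, l] = g^{k nubar} d_m g_{l nubar} (holomorphic indices only)."""
    n = p.shape[-1] // 2
    g = metric_fn(p)
    # Jacobian w.r.t. the 2n real coordinates, shape [n, n, 2n];
    # jacfwd accepts complex outputs for real inputs
    dg = jax.jacfwd(metric_fn)(p)
    # Holomorphic derivative d_m = (d_x - i d_y) / 2, layout [l, nu, m]
    dg_hol = 0.5 * (dg[..., :n] - 1j * dg[..., n:])
    # In this index layout g^{k nubar} = inv(g)[nu, k]
    g_inv = jnp.linalg.inv(g)
    return jnp.einsum("nk,lnm->kml", g_inv, dg_hol)

p = 0.5 * jax.random.normal(jax.random.PRNGKey(2), (6,))
Gamma = christoffel_kahler(fs_metric, p)
print(Gamma.shape)                                     # (3, 3, 3)
print(jnp.allclose(Gamma, jnp.swapaxes(Gamma, 1, 2)))  # True
```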
riem = vmap(curvature.riemann_tensor_kahler, in_axes=(0,None))(p, jax.tree_util.Partial(fubini_study_metric))

This involves two derivatives of a potentially expensive function, but is reasonably speedy even for \(10^4\) points, as we can test by benchmarking. The cell below times the \(N = 10\) points sampled above; increase N to time larger batches. In this case the function is already jit-ed at definition, and nested jits are equivalent to a single jit.

%%timeit
riem = vmap(curvature.riemann_tensor_kahler, in_axes=(0,None))(p, jax.tree_util.Partial(fubini_study_metric)).block_until_ready()
1.54 ms ± 65.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

rtk = partial(curvature.riemann_tensor_kahler, return_aux=True)
_, riem = vmap(rtk, in_axes=(0,None))(p, jax.tree_util.Partial(fubini_study_metric))
riem = jnp.einsum('...abcd, ...ae->...becd', riem, g_FS)
riem.shape
(10, 10, 10, 10, 10)

First Bianchi identity

We form the Riemann tensor with all indices lowered using the musical isomorphism defined by the metric. The resulting tensor satisfies the following symmetries, as a consequence of the first Bianchi identity,

\[ \textsf{Riem}_{a\overline{b}c\overline{d}} = \textsf{Riem}_{a \overline{d} c \overline{b}} = \textsf{Riem}_{c \overline{b} a \overline{d}} = \textsf{Riem}_{c \overline{d} a \overline{b}}~.\]
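As an independent check of these symmetries, the sketch below assembles the all-lowered curvature of a Kähler metric using nested forward-mode autodiff. It uses the standard local formula \(R_{a\bar{b}c\bar{d}} = -\partial_c\overline{\partial}_{\bar{d}}\, g_{a\bar{b}} + g^{p\bar{q}}\,\partial_c g_{a\bar{q}}\,\overline{\partial}_{\bar{d}} g_{p\bar{b}}\). It is a self-contained illustration, not the library's implementation. Sign conventions for the curvature differ between references, but the symmetries do not depend on the overall sign. For the Fubini-Study metric in this convention one also expects \(R_{a\bar{b}c\bar{d}} = g_{a\bar{b}}g_{c\bar{d}} + g_{a\bar{d}}g_{c\bar{b}}\), which is checked at the end.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def fs_metric(p):
    # Fubini-Study metric g[a, b] = g_{a bbar}, with p = (Re zeta; Im zeta)
    n = p.shape[-1] // 2
    zeta = p[:n] + 1j * p[n:]
    sigma = 1.0 + jnp.sum(jnp.abs(zeta) ** 2)
    return jnp.eye(n) / sigma - jnp.outer(jnp.conj(zeta), zeta) / sigma ** 2

def riemann_lowered(metric_fn, p):
    """R[a, b, c, d] = R_{a bbar c dbar} for a Kahler metric given on real coords."""
    n = p.shape[-1] // 2
    g = metric_fn(p)
    dg = jax.jacfwd(metric_fn)(p)                 # [n, n, 2n]
    ddg = jax.jacfwd(jax.jacfwd(metric_fn))(p)    # [n, n, 2n, 2n]
    d_hol = 0.5 * (dg[..., :n] - 1j * dg[..., n:])   # d_c g_{a bbar}
    d_anti = 0.5 * (dg[..., :n] + 1j * dg[..., n:])  # dbar_d g_{a bbar}
    x, y = slice(0, n), slice(n, 2 * n)
    # d_c dbar_d = (1/4)(d_xc - i d_yc)(d_xd + i d_yd)
    dd = 0.25 * (ddg[..., x, x] + 1j * ddg[..., x, y]
                 - 1j * ddg[..., y, x] + ddg[..., y, y])
    g_inv = jnp.linalg.inv(g)  # g^{p qbar} = g_inv[q, p]
    quad = jnp.einsum("qp,aqc,pbd->abcd", g_inv, d_hol, d_anti)
    return -dd + quad

p = 0.5 * jax.random.normal(jax.random.PRNGKey(3), (6,))
R = riemann_lowered(fs_metric, p)
g = fs_metric(p)

# The three symmetries stated above
print(jnp.allclose(R, jnp.einsum("abcd->adcb", R)))  # True
print(jnp.allclose(R, jnp.einsum("abcd->cbad", R)))  # True
print(jnp.allclose(R, jnp.einsum("abcd->cdab", R)))  # True

# Closed form for Fubini-Study in this convention
R_fs = jnp.einsum("ab,cd->abcd", g, g) + jnp.einsum("ad,cb->abcd", g, g)
print(jnp.allclose(R, R_fs))  # True
```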
riem_lower = jnp.einsum('...ibcd, ...ia->...bacd', riem, g_FS)
jnp.allclose(riem_lower, jnp.einsum('...abcd->...adcb', riem_lower))  # first equality
Array(False, dtype=bool)
jnp.allclose(riem_lower, jnp.einsum('...abcd->...cbad', riem_lower))  # second equality
Array(True, dtype=bool)
jnp.allclose(riem_lower, jnp.einsum('...abcd->...cdab', riem_lower))  # third equality
Array(False, dtype=bool)

Ricci curvature

Complex projective space is an Einstein manifold, meaning that the Fubini-Study metric on \(\mathbb{P}^n\) is proportional to the Ricci curvature. The Ricci curvature is another important measure of curvature derived from \(\textsf{Riem}\), which roughly measures the degree of volume distortion relative to Euclidean space as one travels along geodesics emanating from a given point.

\[\textsf{Ric} = \Lambda g~.\]

For \(\mathbb{P}^n\) the Einstein constant is \(\Lambda = n+1\), i.e. \(\textsf{Ric} = (n+1) g_{FS}\).
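The Einstein constant can be cross-checked without the Riemann tensor by a different route. On a Kähler manifold, the Ricci curvature is determined by the metric determinant, \(\textsf{Ric}_{\mu\bar{\nu}} = -\partial_{\mu}\overline{\partial}_{\bar{\nu}} \log\det g\). For Fubini-Study, \(\det g = \sigma^{-(n+1)}\), so this gives \((n+1)\,\partial_{\mu}\overline{\partial}_{\bar{\nu}}\log\sigma = (n+1)\,g_{\mu\bar{\nu}}\). Below is a minimal autodiff sketch of that check; the function names are our own.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def fs_metric(p):
    # Fubini-Study metric g[mu, nu] = g_{mu nubar}, with p = (Re zeta; Im zeta)
    n = p.shape[-1] // 2
    zeta = p[:n] + 1j * p[n:]
    sigma = 1.0 + jnp.sum(jnp.abs(zeta) ** 2)
    return jnp.eye(n) / sigma - jnp.outer(jnp.conj(zeta), zeta) / sigma ** 2

def log_det_metric(p):
    # g is Hermitian positive definite, so log|det g| = log det g (a real scalar)
    _, logabsdet = jnp.linalg.slogdet(fs_metric(p))
    return logabsdet

def ricci_from_log_det(p):
    """Ric_{mu nubar} = -d_mu dbar_nu log det g, via Wirtinger derivatives."""
    n = p.shape[-1] // 2
    H = jax.hessian(log_det_metric)(p)
    ddbar = 0.25 * (H[:n, :n] + H[n:, n:] + 1j * (H[:n, n:] - H[n:, :n]))
    return -ddbar

n = 3
p = 0.5 * jax.random.normal(jax.random.PRNGKey(4), (2 * n,))
ric = ricci_from_log_det(p)
print(jnp.allclose(ric, (n + 1) * fs_metric(p)))  # True
```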

The Ricci curvature is given, in local coordinates, as the trace of the endomorphism part of the Riemann curvature tensor,

\[ \textsf{Ric}_{\mu \bar{\nu}} \triangleq \textsf{Riem}^{\kappa}_{\; \kappa \mu \bar{\nu}} = \textsf{Riem}^{\kappa}_{\; \mu \kappa \bar{\nu}}~.\]
ricci = vmap(curvature.ricci_tensor_kahler, in_axes=(0,None))(p, jax.tree_util.Partial(fubini_study_metric))
jnp.allclose(ricci, (ambient_dim + 1) * g_FS)
Array(True, dtype=bool)

This also means that the Ricci scalar, the trace of the Ricci curvature with respect to the metric, should be, on \(\mathbb{P}^n\):

\[ \textsf{R} = n(n+1)~.\]
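Explicitly, contracting the Einstein condition with the inverse metric gives

\[ \textsf{R} = g^{\mu\bar{\nu}}\,\textsf{Ric}_{\mu\bar{\nu}} = (n+1)\, g^{\mu\bar{\nu}} g_{\mu\bar{\nu}} = (n+1)\,\delta^{\mu}_{\ \mu} = n(n+1)~, \]

so for \(n = 10\), as here, \(\textsf{R} = 110\). The einsum below computes this trace, \(\operatorname{tr}(g^{-1}\textsf{Ric})\), pointwise. The imaginary parts of order \(10^{-15}\) in its output are floating-point roundoff.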
jnp.einsum('...ba, ...ab', jnp.linalg.inv(g_FS), ricci)
Array([110.+8.26087339e-16j, 110.-4.13292887e-16j, 110.+6.68198714e-16j,
       110.-6.83617305e-17j, 110.-1.03265506e-15j, 110.-1.35128649e-15j,
       110.+7.40287060e-16j, 110.+4.10574786e-16j, 110.+1.34577413e-15j,
       110.+1.86926177e-15j], dtype=complex128)