Skip to content

Approximation of the Ricci-flat metric on a manifold¤

In this example notebook we consider a simple complex manifold \(X\), and approximate the metric tensor on \(X\) by optimisation of a variational objective derived from a PDE the metric tensor should solve.

In more detail, the manifold \(X\) is a Calabi-Yau threefold defined as a projective variety in \(\mathbb{P}^5\). This is the (mirror of the) intersection of two cubics in the ambient \(\mathbb{P}^5\).

We wish to approximate the unique Ricci-flat metric on \(X\) - that is, one whose associated Ricci curvature vanishes. We sample points from the manifold and obtain an approximation to the Ricci-flat metric via optimisation of an objective constructed from the Monge-Ampere equation.

import jax
from jax import random
import jax.numpy as jnp

import os, time
import numpy as np

from cymyc.utils.pointgen_cicy import PointGenerator 

Manifold definition¤

The manifold is explicitly defined as the zero locus of the following polynomials in the complex projective space \(\mathbb{P}^5\), where we let \(Z_i\) denote local coordinates in the ambient projective space;

\[\begin{align*} P_1 &= Z_0^3 + Z_1^3 + Z_2^3 - 3 \psi Z_3 Z_4 Z_5~, \\ P_2 &= Z_3^3 + Z_4^3 + Z_5^3 - 3 \psi Z_0 Z_1 Z_2~. \end{align*}\]

The single complex structure moduli direction corresponds to the trilinear polynomial deformations above, parameterised by \(\psi \in \mathbb{C}\). We'll choose the point \(\psi = 0.5\) in moduli space, which is away from singularities in the moduli space - see this article for more details.

# Choose value of moduli parameter psi
psi = 0.5

monomials_1 = np.asarray([
    [3, 0, 0, 0, 0, 0],
    [0, 3, 0, 0, 0, 0],
    [0, 0, 3, 0, 0, 0],
    [0, 0, 0, 1, 1, 1]], dtype=np.int64)

monomials_2 = np.asarray([
    [0, 0, 0, 3, 0, 0],
    [0, 0, 0, 0, 3, 0],
    [0, 0, 0, 0, 0, 3],
    [1, 1, 1, 0, 0, 0]], dtype=np.int64)

monomials = [monomials_1, monomials_2]

cy_dim = 3
kmoduli = np.ones(1)
ambient = np.array([5])
dim = 6

coeff_fn = lambda psi: [np.append(np.ones(3), -3.0*psi), np.append(np.ones(3), -3.0*psi)]
coefficients = coeff_fn(psi)
poly_data = (monomials, cy_dim, kmoduli, ambient)

Point sampling¤

We need to sample points from \(X \hookrightarrow \mathbb{P}^n\) according to a known distribution in order to evaluate integrals via Monte Carlo integration. One way of doing this is to utilise the identification \(\mathbb{P}^n \simeq S^{2n+1} / U(1)\).

First we sample uniformly from the sphere \(S^{2n+1}\), obtaining a uniformly distributed sample over \(\mathbb{P}^n\). Next, one constructs a line connecting pairs of points \(p,q \in \mathbb{P}^n\), and identifies the points of intersection with the zero locus \(\{P_i=0\}_i\). These points are distributed according to a quantifiable distribution \(dA\) on \(X\). We correct for the fact that these points are not uniformly distributed according to the canonical measure on \(X\) by computing the associated importance weights (Douglas et. al (2008)).

seed = 42
rng = random.PRNGKey(seed)
rng, pg_rng, init_rng = random.split(rng, 3)
dpath = "data/X33_demo"

n_p = 400000  # Number of training points
v_p = 200000  # Number of validation points

import warnings
warnings.filterwarnings('ignore')
pg_cicy = PointGenerator(rng, cy_dim, monomials, ambient, coefficients, kmoduli)
cicy_pts = pg_cicy.sample_intersect_cicy(init_rng, n_p + v_p)
pg_cicy.export(dpath, cicy_pts, n_p, v_p, psi, poly_data, coefficients)
Generating 600000 points ...

[Parallel(n_jobs=-1)]: Using backend MultiprocessingBackend with 16 concurrent workers.
[Parallel(n_jobs=-1)]: Done   4 tasks      | elapsed:   16.2s
[Parallel(n_jobs=-1)]: Done 150752 tasks      | elapsed:  1.1min
[Parallel(n_jobs=-1)]: Done 489872 tasks      | elapsed:  2.5min
[Parallel(n_jobs=-1)]: Done 603010 out of 603010 | elapsed:  2.9min finished

Max locus violation: 1.67721e-13
Using kmoduli, [1.]

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:04<00:00,  2.97it/s]

Volume 9*t_0**3
Volume at chosen Kahler moduli 9.0
kappa: 0.0393404

Dataset metadata¤

Here we save the metadata associated with the generated points to disk.

from cymyc import dataloading
from cymyc.utils import gen_utils as utils
from cymyc.approx.default_config import config

class args(object):
    # specify training config. For more options, see `src/approx/default_config`
    name = 'X33_demo'
    learning_rate = 1e-4
    n_epochs = 128  # change as necessary!
    dataset = dpath
    batch_size = 1024
    n_units = [48,48,48,48]

# Override default arguments from config file with provided command line arguments
config = utils.override_default_args(args, config)
config = utils.read_metadata(config)  # load dataset metadata

np_rng = np.random.default_rng()
data_train, data_val, train_loader, val_loader, psi = dataloading.initialize_loaders_train(
    np_rng      = np_rng,
    data_path   = os.path.join(config.dataset, "dataset.npz"),
    batch_size  = config.batch_size)
Saving config file to experiments/X33_demo/X33_demo_METADATA.pkl
Dataset size: (400000, 12), kappa: 0.0393404
Vol[g]: 0.0277778, Vol[Ω]: 0.7060880

Ricci-flat metric optimisation¤

import optax
from functools import partial

from cymyc.approx import models
from cymyc.approx.train import create_train_state, train_step, callback

Model construction¤

Here we construct a spectral neural network to approximate the Ricci-flat metric. This constructed as an \(\partial\overline{\partial}\)-exact correction to the corresponding Kähler form of some reference metric, one which is easily computable (Larfors et. al. (2022)). $$ \tilde{\omega} = \omega_{\text{ref}} + \partial \overline{\partial} \phi~.$$ Thus the task reduces to finding a globally defined function \(\phi \in C^{\infty}(X)\) s.t. the metric corresponding to the approximating Kähler form \(\tilde{\omega}\) satisfies the Ricci-flat condition, which says that the volume form induced by the metric coincides with the canonical volume form, up to a constant function, $$ \omega^n \propto \Omega \wedge \overline{\Omega} ~.$$ This is a nonlinear PDE on \(X\) for the function \(\phi\) - we encode this condition into a variational objective, discretise this via a neural network ansatz, and optimise to obtain a parameterised function describing a approximately Ricci-flat \(\tilde{\omega}\). Note we are not using the Ricci curvature as the objective itself, as evaluation of third-order derivatives of a neural network is quite expensive - this is possible with our library though.

In more detail, our ansatz consists of a projection of the input data \(p \in \mathbb{P}^n\) into a \(\mathbb{C}^*\)-invariant form, via the mapping \(\alpha_{n}\colon \mathbb{P}^{n} \longrightarrow \mathbb{C}^{n+1,n+1}\), whose action on a general point \(p\in [Z_0\colon Z_1\colon \dots\colon Z_{n}]\in\mathbb{P}^{n}\) is defined as:

\[\alpha_{n}(p) = \left[\begin{matrix} \displaystyle \frac{Z_0 \overline{Z_0}}{|Z|^2} && \displaystyle\frac{Z_0 \overline{Z_1}}{|Z|^2} && \dots && \displaystyle\frac{Z_0 \overline{Z_{n}}}{|Z|^2} \\ \displaystyle\frac{Z_1 \overline{Z_0}}{|Z|^2} && \displaystyle\frac{Z_1 \overline{Z_1}}{|Z|^2} && \dots && \displaystyle\frac{Z_1 \overline{Z_{n}}}{|Z|^2} \\ \vdots && \vdots && \ddots && \vdots \\ \displaystyle \frac{Z_{n} \overline{Z_0}}{|Z|^2} && \displaystyle\frac{Z_{n} \overline{Z_1}}{|Z|^2} && \dots && \displaystyle\frac{Z_{n} \overline{Z_{n}}}{|Z|^2} \end{matrix}\right] ~.\]

This is followed by the conversion to real coordinates and the application of a standard feedforward neural network as below - this ensures the overall function learnt is a well-defined function on the ambient projective space.

spec_arch

Further note that,

  • The default architecture we use in all our experiments is a four-layer net with 48 units each and results are empirically insensitive to the exact choice of architecture/hyperparameters.
  • The spectral layer adapts to the size of the input - therefore a projective variety with a higher-dimensional description of local coordinates on the ambient space entails an architecture of higher complexity.

Below we define the model and optimiser.

key, _key = jax.random.split(rng)

metric_model = models.LearnedVector_spectral_nn_CICY(
    dim=dim, ambient=ambient, n_units=config.n_units)

g_FS_fn, g_correction_fn, *_ = models.helper_fns(config)

optimizer = optax.adamw(config.learning_rate)
metric_params, opt_state, init_rng = create_train_state(_key, metric_model, optimizer, data_dim=dim * 2)
# partial closure
metric_fn = partial(models.ddbar_phi_model, g_ref_fn=g_FS_fn, g_correction_fn=g_correction_fn)

t0 = time.time()
logger = utils.logger_setup('X33_demo', filepath=os.path.abspath(''))
logger.info(metric_model.tabulate(init_rng, jnp.ones([1, config.n_ambient_coords * 2])))
Compiling LearnedVector_spectral_nn.spectral_layer.

14:25:26 INFO - logger_setup: /home/jt796/github/cymyc/docs/examples
14:25:26 INFO - <module>: 
                     LearnedVector_spectral_nn_CICY Summary                     
┏━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓
┃ path      module            inputs         outputs      params           ┃
┡━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩
│          │ LearnedVector_s… │ float64[1,12] │ float64[]   │                  │
├──────────┼──────────────────┼───────────────┼─────────────┼──────────────────┤
│ layers_0 │ Dense            │ float64[36]   │ float64[48] │ bias:            │
│          │                  │               │             │ float32[48]      │
│          │                  │               │             │ kernel:          │
│          │                  │               │             │ float32[36,48]   │
│          │                  │               │             │                  │
│          │                  │               │             │ 1,776 (7.1 KB)   │
├──────────┼──────────────────┼───────────────┼─────────────┼──────────────────┤
│ layers_1 │ Dense            │ float64[48]   │ float64[48] │ bias:            │
│          │                  │               │             │ float32[48]      │
│          │                  │               │             │ kernel:          │
│          │                  │               │             │ float32[48,48]   │
│          │                  │               │             │                  │
│          │                  │               │             │ 2,352 (9.4 KB)   │
├──────────┼──────────────────┼───────────────┼─────────────┼──────────────────┤
│ layers_2 │ Dense            │ float64[48]   │ float64[48] │ bias:            │
│          │                  │               │             │ float32[48]      │
│          │                  │               │             │ kernel:          │
│          │                  │               │             │ float32[48,48]   │
│          │                  │               │             │                  │
│          │                  │               │             │ 2,352 (9.4 KB)   │
├──────────┼──────────────────┼───────────────┼─────────────┼──────────────────┤
│ layers_3 │ Dense            │ float64[48]   │ float64[48] │ bias:            │
│          │                  │               │             │ float32[48]      │
│          │                  │               │             │ kernel:          │
│          │                  │               │             │ float32[48,48]   │
│          │                  │               │             │                  │
│          │                  │               │             │ 2,352 (9.4 KB)   │
├──────────┼──────────────────┼───────────────┼─────────────┼──────────────────┤
│ scalar   │ Dense            │ float64[48]   │ float64[1]  │ bias: float32[1] │
│          │                  │               │             │ kernel:          │
│          │                  │               │             │ float32[48,1]    │
│          │                  │               │             │                  │
│          │                  │               │             │ 49 (196 B)       │
├──────────┼──────────────────┼───────────────┼─────────────┼──────────────────┤
│                                                  Total  8,881 (35.5 KB)  │
└──────────┴──────────────────┴───────────────┴─────────────┴──────────────────┘
                                                                                
                       Total Parameters: 8,881 (35.5 KB)                        



Compiling LearnedVector_spectral_nn.spectral_layer.

Optimisation loop¤

This is fairly standard - note Jax is more bare-metal than other libraries, so we write the looping logic ourselves.

jit compilation introduces a delay the first time the train_step function is called, but executes quickly when called subsequently. We pay an initial up-front cost for compilation of Python functions into a form efficiently executable by an accelerator, which will be repaid during the execution itself.

import time
from tqdm import tqdm
from collections import defaultdict

eval_interval = 5
storage = defaultdict(list)

try:
    device = jax.devices('gpu')[0]
except:
    device = jax.devices('cpu')[0]

with jax.default_device(device):
    logger.info(f"Running on {device}")

    for epoch in range(config.n_epochs):
        if (epoch % eval_interval == 0):
            val_loader, val_data = dataloading.get_validation_data(val_loader, config.batch_size, data_val, np_rng)
            storage = callback(epoch, t0, 0, val_data, metric_params, metric_fn, g_FS_fn, config, storage, logger, mode='VAL')

        if epoch &gt; 0:
            train_loader = dataloading.data_loader(data_train, config.batch_size, np_rng)

        train_loader_it = tqdm(train_loader, desc=f"Epoch: {epoch}", total=data_train[0].shape[0]//config.batch_size,
                               colour='green', mininterval=0.1)
        for t, data in enumerate(train_loader_it):
            metric_params, opt_state, loss = train_step(data, metric_params, opt_state, metric_fn, optimizer, config.kappa)
            train_loader_it.set_postfix_str(f"loss: {loss:.5f}", refresh=False)

# save parameters to disk
utils.basic_ckpt(metric_params, opt_state, config.name, 'FIN')
utils.save_logs(storage, config.name, 'FIN')
14:25:26 INFO - <module>: Running on cuda:0

Compiling ddbar_phi_model
Compiling phi_head
Compiling LearnedVector_spectral_nn.spectral_layer.
Compiling ricci_measure

14:26:18 INFO - callback: [51.7s]: [VAL] | Epoch: 0 | Iter: 0 | chi_form: -147.8624+0.0102j | det_g: 0.0018 | einstein_norm: 4.0715 | kahler_loss: 0.0000 | monge_ampere_loss: 0.2222 | ricci_measure: 0.1323 | ricci_scalar: -2.6378 | ricci_tensor_norm: 0.1802 | sigma_measure: 0.3148 | vol_CY: 0.0279 | vol_Omega: 0.7042 | vol_loss: 0.0001
Epoch: 0:   0%|                                                                                                                                                                                                                                       | 0/390 [00:00<?, ?it/s]
Compiling ddbar_phi_model

Epoch: 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:05<00:00, 76.54it/s, loss: 0.09083]
Epoch: 1: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 572.91it/s, loss: 0.06886]
Epoch: 2: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 580.58it/s, loss: 0.04289]
Epoch: 3: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 583.24it/s, loss: 0.02926]
Epoch: 4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 589.58it/s, loss: 0.02722]
14:26:28 INFO - callback: [61.9s]: [VAL] | Epoch: 5 | Iter: 0 | chi_form: -149.3161-0.0202j | det_g: 0.0017 | einstein_norm: 1.2555 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0270 | ricci_measure: 0.0435 | ricci_scalar: -0.2172 | ricci_tensor_norm: 0.0606 | sigma_measure: 0.0381 | vol_CY: 0.0278 | vol_Omega: 0.7086 | vol_loss: 0.0000
Epoch: 5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 534.56it/s, loss: 0.02547]
Epoch: 6: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 561.00it/s, loss: 0.02498]
Epoch: 7: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 589.92it/s, loss: 0.02409]
Epoch: 8: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 562.07it/s, loss: 0.02171]
Epoch: 9: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 557.11it/s, loss: 0.02210]
14:26:34 INFO - callback: [67.2s]: [VAL] | Epoch: 10 | Iter: 0 | chi_form: -146.9219+0.0199j | det_g: 0.0018 | einstein_norm: 1.1907 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0226 | ricci_measure: 0.0381 | ricci_scalar: 0.0463 | ricci_tensor_norm: 0.0576 | sigma_measure: 0.0321 | vol_CY: 0.0277 | vol_Omega: 0.7050 | vol_loss: 0.0000
Epoch: 10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 542.18it/s, loss: 0.02192]
Epoch: 11: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 551.39it/s, loss: 0.02137]
Epoch: 12: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 560.51it/s, loss: 0.01998]
Epoch: 13: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 561.09it/s, loss: 0.02048]
Epoch: 14: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 567.97it/s, loss: 0.01986]
14:26:39 INFO - callback: [72.7s]: [VAL] | Epoch: 15 | Iter: 0 | chi_form: -143.9541+0.0050j | det_g: 0.0017 | einstein_norm: 1.1205 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0191 | ricci_measure: 0.0373 | ricci_scalar: 0.1093 | ricci_tensor_norm: 0.0530 | sigma_measure: 0.0270 | vol_CY: 0.0279 | vol_Omega: 0.7076 | vol_loss: 0.0001
Epoch: 15: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 548.10it/s, loss: 0.01922]
Epoch: 16: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 577.24it/s, loss: 0.01867]
Epoch: 17: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 575.49it/s, loss: 0.01838]
Epoch: 18: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 568.85it/s, loss: 0.01876]
Epoch: 19: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 568.61it/s, loss: 0.01843]
14:26:44 INFO - callback: [78.0s]: [VAL] | Epoch: 20 | Iter: 0 | chi_form: -145.8684+0.0103j | det_g: 0.0017 | einstein_norm: 1.0711 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0172 | ricci_measure: 0.0343 | ricci_scalar: -0.0931 | ricci_tensor_norm: 0.0502 | sigma_measure: 0.0245 | vol_CY: 0.0276 | vol_Omega: 0.7030 | vol_loss: 0.0001
Epoch: 20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 521.85it/s, loss: 0.01772]
Epoch: 21: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 531.52it/s, loss: 0.01840]
Epoch: 22: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 529.54it/s, loss: 0.01716]
Epoch: 23: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 554.46it/s, loss: 0.01619]
Epoch: 24: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 573.31it/s, loss: 0.01558]
14:26:50 INFO - callback: [83.5s]: [VAL] | Epoch: 25 | Iter: 0 | chi_form: -141.6123-0.0002j | det_g: 0.0017 | einstein_norm: 1.1405 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0162 | ricci_measure: 0.0370 | ricci_scalar: -0.0062 | ricci_tensor_norm: 0.0506 | sigma_measure: 0.0229 | vol_CY: 0.0278 | vol_Omega: 0.7065 | vol_loss: 0.0000
Epoch: 25: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 527.47it/s, loss: 0.01658]
Epoch: 26: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 529.48it/s, loss: 0.01555]
Epoch: 27: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 555.86it/s, loss: 0.01646]
Epoch: 28: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 573.49it/s, loss: 0.01629]
Epoch: 29: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 566.63it/s, loss: 0.01639]
14:26:55 INFO - callback: [89.0s]: [VAL] | Epoch: 30 | Iter: 0 | chi_form: -136.6402-0.0538j | det_g: 0.0017 | einstein_norm: 1.0027 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0154 | ricci_measure: 0.0323 | ricci_scalar: 0.0722 | ricci_tensor_norm: 0.0476 | sigma_measure: 0.0218 | vol_CY: 0.0277 | vol_Omega: 0.7025 | vol_loss: 0.0001
Epoch: 30: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 532.89it/s, loss: 0.01503]
Epoch: 31: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 538.50it/s, loss: 0.01551]
Epoch: 32: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 566.12it/s, loss: 0.01464]
Epoch: 33: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 579.78it/s, loss: 0.01483]
Epoch: 34: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 582.54it/s, loss: 0.01501]
14:27:01 INFO - callback: [94.4s]: [VAL] | Epoch: 35 | Iter: 0 | chi_form: -153.3210+0.0054j | det_g: 0.0018 | einstein_norm: 1.1005 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0161 | ricci_measure: 0.0376 | ricci_scalar: -0.2312 | ricci_tensor_norm: 0.0474 | sigma_measure: 0.0227 | vol_CY: 0.0280 | vol_Omega: 0.7134 | vol_loss: 0.0002
Epoch: 35: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 540.23it/s, loss: 0.01448]
Epoch: 36: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 558.39it/s, loss: 0.01574]
Epoch: 37: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 572.88it/s, loss: 0.01526]
Epoch: 38: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 562.43it/s, loss: 0.01472]
Epoch: 39: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 539.29it/s, loss: 0.01496]
14:27:06 INFO - callback: [99.8s]: [VAL] | Epoch: 40 | Iter: 0 | chi_form: -146.8696-0.0111j | det_g: 0.0017 | einstein_norm: 1.0653 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0155 | ricci_measure: 0.0322 | ricci_scalar: 0.1571 | ricci_tensor_norm: 0.0465 | sigma_measure: 0.0218 | vol_CY: 0.0278 | vol_Omega: 0.7052 | vol_loss: 0.0000
Epoch: 40: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 568.74it/s, loss: 0.01487]
Epoch: 41: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 565.72it/s, loss: 0.01568]
Epoch: 42: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 566.88it/s, loss: 0.01412]
Epoch: 43: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 575.04it/s, loss: 0.01551]
Epoch: 44: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 560.32it/s, loss: 0.01417]
14:27:11 INFO - callback: [105.2s]: [VAL] | Epoch: 45 | Iter: 0 | chi_form: -153.1961+0.0116j | det_g: 0.0017 | einstein_norm: 1.0842 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0153 | ricci_measure: 0.0343 | ricci_scalar: -0.1280 | ricci_tensor_norm: 0.0449 | sigma_measure: 0.0218 | vol_CY: 0.0279 | vol_Omega: 0.7093 | vol_loss: 0.0001
Epoch: 45: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 547.89it/s, loss: 0.01515]
Epoch: 46: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 551.31it/s, loss: 0.01372]
Epoch: 47: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 573.33it/s, loss: 0.01479]
Epoch: 48: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 559.24it/s, loss: 0.01452]
Epoch: 49: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 572.21it/s, loss: 0.01526]
14:27:17 INFO - callback: [110.6s]: [VAL] | Epoch: 50 | Iter: 0 | chi_form: -132.6214+0.0156j | det_g: 0.0017 | einstein_norm: 0.9498 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0137 | ricci_measure: 0.0310 | ricci_scalar: 0.3576 | ricci_tensor_norm: 0.0425 | sigma_measure: 0.0192 | vol_CY: 0.0277 | vol_Omega: 0.7025 | vol_loss: 0.0001
Epoch: 50: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 564.28it/s, loss: 0.01403]
Epoch: 51: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 558.20it/s, loss: 0.01402]
Epoch: 52: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 569.94it/s, loss: 0.01399]
Epoch: 53: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 562.36it/s, loss: 0.01385]
Epoch: 54: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 562.99it/s, loss: 0.01393]
14:27:22 INFO - callback: [115.9s]: [VAL] | Epoch: 55 | Iter: 0 | chi_form: -141.8868+0.0201j | det_g: 0.0017 | einstein_norm: 1.0203 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0144 | ricci_measure: 0.0272 | ricci_scalar: 0.1503 | ricci_tensor_norm: 0.0422 | sigma_measure: 0.0203 | vol_CY: 0.0278 | vol_Omega: 0.7064 | vol_loss: 0.0000
Epoch: 55: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 572.05it/s, loss: 0.01462]
Epoch: 56: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 584.43it/s, loss: 0.01446]
Epoch: 57: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 567.86it/s, loss: 0.01493]
Epoch: 58: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 565.89it/s, loss: 0.01497]
Epoch: 59: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 563.96it/s, loss: 0.01303]
14:27:28 INFO - callback: [121.2s]: [VAL] | Epoch: 60 | Iter: 0 | chi_form: -142.2032+0.0036j | det_g: 0.0017 | einstein_norm: 0.9002 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0140 | ricci_measure: 0.0273 | ricci_scalar: 0.0055 | ricci_tensor_norm: 0.0406 | sigma_measure: 0.0198 | vol_CY: 0.0278 | vol_Omega: 0.7060 | vol_loss: 0.0000
Epoch: 60: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 550.94it/s, loss: 0.01336]
Epoch: 61: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 547.94it/s, loss: 0.01393]
Epoch: 62: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 545.21it/s, loss: 0.01484]
Epoch: 63: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 548.59it/s, loss: 0.01385]
Epoch: 64: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 544.52it/s, loss: 0.01368]
14:27:33 INFO - callback: [126.7s]: [VAL] | Epoch: 65 | Iter: 0 | chi_form: -150.3081+0.0003j | det_g: 0.0017 | einstein_norm: 0.9176 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0136 | ricci_measure: 0.0289 | ricci_scalar: -0.1310 | ricci_tensor_norm: 0.0419 | sigma_measure: 0.0195 | vol_CY: 0.0277 | vol_Omega: 0.7054 | vol_loss: 0.0001
Epoch: 65: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 562.41it/s, loss: 0.01326]
Epoch: 66: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 556.48it/s, loss: 0.01411]
Epoch: 67: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 554.83it/s, loss: 0.01347]
Epoch: 68: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 549.33it/s, loss: 0.01386]
Epoch: 69: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 559.46it/s, loss: 0.01356]
14:27:38 INFO - callback: [132.1s]: [VAL] | Epoch: 70 | Iter: 0 | chi_form: -133.0434-0.0148j | det_g: 0.0018 | einstein_norm: 1.0014 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0133 | ricci_measure: 0.0276 | ricci_scalar: 0.2083 | ricci_tensor_norm: 0.0416 | sigma_measure: 0.0187 | vol_CY: 0.0278 | vol_Omega: 0.7047 | vol_loss: 0.0000
Epoch: 70: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 544.78it/s, loss: 0.01418]
Epoch: 71: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 557.75it/s, loss: 0.01356]
Epoch: 72: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 567.88it/s, loss: 0.01371]
Epoch: 73: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 556.92it/s, loss: 0.01426]
Epoch: 74: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 557.79it/s, loss: 0.01390]
14:27:44 INFO - callback: [137.5s]: [VAL] | Epoch: 75 | Iter: 0 | chi_form: -144.0889+0.0023j | det_g: 0.0018 | einstein_norm: 0.9646 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0145 | ricci_measure: 0.0259 | ricci_scalar: -0.0402 | ricci_tensor_norm: 0.0418 | sigma_measure: 0.0207 | vol_CY: 0.0277 | vol_Omega: 0.7032 | vol_loss: 0.0001
Epoch: 75: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 545.42it/s, loss: 0.01353]
Epoch: 76: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 567.56it/s, loss: 0.01373]
Epoch: 77: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 563.65it/s, loss: 0.01263]
Epoch: 78: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 583.76it/s, loss: 0.01417]
Epoch: 79: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 577.58it/s, loss: 0.01379]
14:27:49 INFO - callback: [142.8s]: [VAL] | Epoch: 80 | Iter: 0 | chi_form: -147.1411-0.0089j | det_g: 0.0017 | einstein_norm: 1.0180 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0142 | ricci_measure: 0.0285 | ricci_scalar: 0.0157 | ricci_tensor_norm: 0.0419 | sigma_measure: 0.0200 | vol_CY: 0.0279 | vol_Omega: 0.7100 | vol_loss: 0.0001
Epoch: 80: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 526.09it/s, loss: 0.01380]
Epoch: 81: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 548.99it/s, loss: 0.01327]
Epoch: 82: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 553.80it/s, loss: 0.01450]
Epoch: 83: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 548.24it/s, loss: 0.01356]
Epoch: 84: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 572.65it/s, loss: 0.01379]
14:27:55 INFO - callback: [148.3s]: [VAL] | Epoch: 85 | Iter: 0 | chi_form: -138.4507+0.0099j | det_g: 0.0017 | einstein_norm: 0.9589 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0136 | ricci_measure: 0.0249 | ricci_scalar: 0.0912 | ricci_tensor_norm: 0.0414 | sigma_measure: 0.0193 | vol_CY: 0.0278 | vol_Omega: 0.7053 | vol_loss: 0.0000
Epoch: 85: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 541.77it/s, loss: 0.01447]
Epoch: 86: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 556.41it/s, loss: 0.01459]
Epoch: 87: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 543.58it/s, loss: 0.01386]
Epoch: 88: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 556.47it/s, loss: 0.01305]
Epoch: 89: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 569.38it/s, loss: 0.01432]
14:28:00 INFO - callback: [153.8s]: [VAL] | Epoch: 90 | Iter: 0 | chi_form: -146.9192-0.0090j | det_g: 0.0017 | einstein_norm: 0.9270 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0141 | ricci_measure: 0.0327 | ricci_scalar: -0.1549 | ricci_tensor_norm: 0.0407 | sigma_measure: 0.0203 | vol_CY: 0.0276 | vol_Omega: 0.7013 | vol_loss: 0.0002
Epoch: 90: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 549.42it/s, loss: 0.01304]
Epoch: 91: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 553.62it/s, loss: 0.01467]
Epoch: 92: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 561.20it/s, loss: 0.01429]
Epoch: 93: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 569.88it/s, loss: 0.01469]
Epoch: 94: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 555.88it/s, loss: 0.01326]
14:28:05 INFO - callback: [159.2s]: [VAL] | Epoch: 95 | Iter: 0 | chi_form: -134.2157+0.0087j | det_g: 0.0017 | einstein_norm: 0.8699 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0125 | ricci_measure: 0.0267 | ricci_scalar: 0.1765 | ricci_tensor_norm: 0.0394 | sigma_measure: 0.0178 | vol_CY: 0.0276 | vol_Omega: 0.6998 | vol_loss: 0.0002
Epoch: 95: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 564.26it/s, loss: 0.01420]
Epoch: 96: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 569.56it/s, loss: 0.01387]
Epoch: 97: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 571.52it/s, loss: 0.01350]
Epoch: 98: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 573.66it/s, loss: 0.01259]
Epoch: 99: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 584.08it/s, loss: 0.01261]
14:28:11 INFO - callback: [164.5s]: [VAL] | Epoch: 100 | Iter: 0 | chi_form: -145.4254+0.0072j | det_g: 0.0017 | einstein_norm: 0.9189 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0136 | ricci_measure: 0.0279 | ricci_scalar: -0.0306 | ricci_tensor_norm: 0.0406 | sigma_measure: 0.0193 | vol_CY: 0.0279 | vol_Omega: 0.7102 | vol_loss: 0.0001
Epoch: 100: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 541.49it/s, loss: 0.01350]
Epoch: 101: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 557.09it/s, loss: 0.01390]
Epoch: 102: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 574.05it/s, loss: 0.01449]
Epoch: 103: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 565.91it/s, loss: 0.01397]
Epoch: 104: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 558.86it/s, loss: 0.01302]
14:28:16 INFO - callback: [169.9s]: [VAL] | Epoch: 105 | Iter: 0 | chi_form: -149.9376-0.0439j | det_g: 0.0018 | einstein_norm: 1.0005 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0140 | ricci_measure: 0.0331 | ricci_scalar: 0.0176 | ricci_tensor_norm: 0.0409 | sigma_measure: 0.0199 | vol_CY: 0.0278 | vol_Omega: 0.7065 | vol_loss: 0.0000
Epoch: 105: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 576.60it/s, loss: 0.01335]
Epoch: 106: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 590.34it/s, loss: 0.01331]
Epoch: 107: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 590.94it/s, loss: 0.01346]
Epoch: 108: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 588.17it/s, loss: 0.01293]
Epoch: 109: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 579.40it/s, loss: 0.01316]
14:28:21 INFO - callback: [175.1s]: [VAL] | Epoch: 110 | Iter: 0 | chi_form: -138.6156+0.0104j | det_g: 0.0017 | einstein_norm: 1.0735 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0131 | ricci_measure: 0.0251 | ricci_scalar: 0.0507 | ricci_tensor_norm: 0.0405 | sigma_measure: 0.0185 | vol_CY: 0.0278 | vol_Omega: 0.7051 | vol_loss: 0.0000
Epoch: 110: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 528.55it/s, loss: 0.01295]
Epoch: 111: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 552.33it/s, loss: 0.01392]
Epoch: 112: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 546.67it/s, loss: 0.01305]
Epoch: 113: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 547.34it/s, loss: 0.01265]
Epoch: 114: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 558.66it/s, loss: 0.01343]
14:28:27 INFO - callback: [180.6s]: [VAL] | Epoch: 115 | Iter: 0 | chi_form: -143.7961+0.0100j | det_g: 0.0017 | einstein_norm: 0.9358 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0139 | ricci_measure: 0.0269 | ricci_scalar: 0.0594 | ricci_tensor_norm: 0.0403 | sigma_measure: 0.0196 | vol_CY: 0.0279 | vol_Omega: 0.7098 | vol_loss: 0.0001
Epoch: 115: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 542.10it/s, loss: 0.01215]
Epoch: 116: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 554.72it/s, loss: 0.01322]
Epoch: 117: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 566.17it/s, loss: 0.01286]
Epoch: 118: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 554.81it/s, loss: 0.01353]
Epoch: 119: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 549.84it/s, loss: 0.01315]
14:28:32 INFO - callback: [186.0s]: [VAL] | Epoch: 120 | Iter: 0 | chi_form: -141.6601-0.0063j | det_g: 0.0017 | einstein_norm: 0.8899 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0130 | ricci_measure: 0.0281 | ricci_scalar: 0.0299 | ricci_tensor_norm: 0.0384 | sigma_measure: 0.0183 | vol_CY: 0.0279 | vol_Omega: 0.7103 | vol_loss: 0.0002
Epoch: 120: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 548.59it/s, loss: 0.01400]
Epoch: 121: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 542.58it/s, loss: 0.01359]
Epoch: 122: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 538.37it/s, loss: 0.01296]
Epoch: 123: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 542.21it/s, loss: 0.01368]
Epoch: 124: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 551.13it/s, loss: 0.01362]
14:28:38 INFO - callback: [191.5s]: [VAL] | Epoch: 125 | Iter: 0 | chi_form: -145.2401-0.0202j | det_g: 0.0017 | einstein_norm: 0.8886 | kahler_loss: 0.0000 | monge_ampere_loss: 0.0130 | ricci_measure: 0.0244 | ricci_scalar: -0.0146 | ricci_tensor_norm: 0.0386 | sigma_measure: 0.0186 | vol_CY: 0.0276 | vol_Omega: 0.7028 | vol_loss: 0.0001
Epoch: 125: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 554.48it/s, loss: 0.01296]
Epoch: 126: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 559.61it/s, loss: 0.01254]
Epoch: 127: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:00<00:00, 554.74it/s, loss: 0.01302]

Sanity check¤

As a sanity check, we may verify that the resulting metric is approximately Ricci-flat by investigating the behaviour of the Ricci tensor \(\textsf{Ric}\) and Ricci scalar \(R\), defined in terms of the Riemann curvature tensor \(\textsf{Riem} \in \Omega^2(X; \textsf{End}(T_X))\), respectively, over training: \begin{align} \textsf{Ric}(X,Y) &= \textsf{Tr}\left(Z \mapsto \textsf{Riem}(Z,X)Y\right), \quad X, Y \in \Gamma(T_X)~,\ R &= \textsf{Tr}(\textsf{Ric})~. \end{align}

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
plt.rcParams.update({'font.size': 14})
# remove if no local tex installation
plt.rcParams['text.usetex'] = True 
plt.rcParams['text.latex.preamble'] = r'\usepackage[cm]{sfmath} \usepackage{amssymb} \usepackage{mathrsfs} \usepackage{amsmath}'
plt.rcParams['font.sans-serif'] = 'cm'
fig = plt.figure(figsize=(17,7))
gs=GridSpec(1,2)
ax1=fig.add_subplot(gs[0,0])
ax2=fig.add_subplot(gs[0,1])

S = np.abs(storage['ricci_scalar'])
n = 32
ax1.plot(np.arange(len(S))[:n], np.abs(S)[:n], c='royalblue')
ax1.set_xlabel(f'Epochs/{eval_interval}')
ax1.set_ylabel(r'$R$')
ax1.grid(True, 'both')

R = storage['ricci_tensor_norm']
ax2.plot(np.arange(len(R))[:n], np.abs(R)[:n], c='royalblue')
ax2.set_xlabel(f'Epochs/{eval_interval}')
ax2.set_ylabel(r'$\Vert \textsf{Ric} \Vert_2$')
ax2.grid(True, 'both')
No description has been provided for this image

Now that we have an approximation to the Ricci-flat metric on \(X\), we can use this to do geometry - e.g. measure lengths, areas, volumes, and study the spectrum of various differential operators. See the harmonic_forms example notebook, where we use the learnt metric to find the zero modes of the Laplacian on \(X\).