Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 10 additions & 125 deletions src/vmecpp/cpp/vmecpp/vmec/ideal_mhd_model/ideal_mhd_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "vmecpp/vmec/fourier_geometry/fourier_geometry.h"
#include "vmecpp/vmec/handover_storage/handover_storage.h"
#include "vmecpp/vmec/ideal_mhd_model/jacobian_kernel.h"
#include "vmecpp/vmec/ideal_mhd_model/metric_kernel.h"
#include "vmecpp/vmec/radial_partitioning/radial_partitioning.h"
#include "vmecpp/vmec/radial_profiles/radial_profiles.h"
#include "vmecpp/vmec/vmec_constants/vmec_algorithm_constants.h"
Expand Down Expand Up @@ -1238,131 +1239,15 @@ void IdealMhdModel::computeJacobian() {
}

void IdealMhdModel::computeMetricElements() {
// gsqrt
// guu, guv, gvv

// contributions from full-grid surface _i_nside j-th half-grid surface
int j0 = r_.nsMinF1;
for (int kl = 0; kl < s_.nZnT; ++kl) {
m_ls_.r1e_i[kl] = r1_e[(j0 - r_.nsMinF1) * s_.nZnT + kl];
m_ls_.r1o_i[kl] = r1_o[(j0 - r_.nsMinF1) * s_.nZnT + kl];
m_ls_.z1e_i[kl] = z1_e[(j0 - r_.nsMinF1) * s_.nZnT + kl];
m_ls_.z1o_i[kl] = z1_o[(j0 - r_.nsMinF1) * s_.nZnT + kl];
m_ls_.rue_i[kl] = ru_e[(j0 - r_.nsMinF1) * s_.nZnT + kl];
m_ls_.ruo_i[kl] = ru_o[(j0 - r_.nsMinF1) * s_.nZnT + kl];
m_ls_.zue_i[kl] = zu_e[(j0 - r_.nsMinF1) * s_.nZnT + kl];
m_ls_.zuo_i[kl] = zu_o[(j0 - r_.nsMinF1) * s_.nZnT + kl];
if (s_.lthreed) {
m_ls_.rve_i[kl] = rv_e[(j0 - r_.nsMinF1) * s_.nZnT + kl];
m_ls_.rvo_i[kl] = rv_o[(j0 - r_.nsMinF1) * s_.nZnT + kl];
m_ls_.zve_i[kl] = zv_e[(j0 - r_.nsMinF1) * s_.nZnT + kl];
m_ls_.zvo_i[kl] = zv_o[(j0 - r_.nsMinF1) * s_.nZnT + kl];
}
}

// s on inner full-grid pos
double sF_i =
m_p_.sqrtSF[r_.nsMinH - r_.nsMinF1] * m_p_.sqrtSF[r_.nsMinH - r_.nsMinF1];

for (int jH = r_.nsMinH; jH < r_.nsMaxH; ++jH) {
// s on outside full-grid pos
double sF_o =
m_p_.sqrtSF[jH + 1 - r_.nsMinF1] * m_p_.sqrtSF[jH + 1 - r_.nsMinF1];

// sqrt(s) on j-th half-grid pos
double sqrtSH = m_p_.sqrtSH[jH - r_.nsMinH];

for (int kl = 0; kl < s_.nZnT; ++kl) {
int iHalf = (jH - r_.nsMinH) * s_.nZnT + kl;

// Re-use this loop to compute Jacobian gsqrt=tau*R
// only tau needed to be checked for a sign change,
// so skip the last part where gsqrt is computed
// if a sign changed happened by computing it only here
// (which will only be reached when tau did not change sign).
gsqrt[iHalf] = tau[iHalf] * r12[iHalf];

// contributions from full-grid surface _o_utside j-th half-grid surface
double r1e_o = r1_e[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl];
double r1o_o = r1_o[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl];
double rue_o = ru_e[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl];
double ruo_o = ru_o[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl];
double zue_o = zu_e[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl];
double zuo_o = zu_o[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl];

// g_{\theta,\theta} is needed for both 2D and 3D cases
guu[iHalf] = 0.5 * ((m_ls_.rue_i[kl] * m_ls_.rue_i[kl] +
m_ls_.zue_i[kl] * m_ls_.zue_i[kl]) +
(rue_o * rue_o + zue_o * zue_o) +
sF_i * (m_ls_.ruo_i[kl] * m_ls_.ruo_i[kl] +
m_ls_.zuo_i[kl] * m_ls_.zuo_i[kl]) +
sF_o * (ruo_o * ruo_o + zuo_o * zuo_o)) +
sqrtSH * ((m_ls_.rue_i[kl] * m_ls_.ruo_i[kl] +
m_ls_.zue_i[kl] * m_ls_.zuo_i[kl]) +
(rue_o * ruo_o + zue_o * zuo_o));

// g_{\zeta,\zeta} reduces to R^2 in the 2D case, so compute this always
gvv[iHalf] = 0.5 * (m_ls_.r1e_i[kl] * m_ls_.r1e_i[kl] + r1e_o * r1e_o +
sF_i * m_ls_.r1o_i[kl] * m_ls_.r1o_i[kl] +
sF_o * r1o_o * r1o_o) +
sqrtSH * (m_ls_.r1e_i[kl] * m_ls_.r1o_i[kl] + r1e_o * r1o_o);

if (s_.lthreed) {
double rve_o = rv_e[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl];
double rvo_o = rv_o[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl];
double zve_o = zv_e[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl];
double zvo_o = zv_o[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl];

// g_{\theta,\zeta} is only needed for the 3D case
guv[iHalf] = 0.5 * ((m_ls_.rue_i[kl] * m_ls_.rve_i[kl] +
m_ls_.zue_i[kl] * m_ls_.zve_i[kl]) +
(rue_o * rve_o + zue_o * zve_o) +
sF_i * (m_ls_.ruo_i[kl] * m_ls_.rvo_i[kl] +
m_ls_.zuo_i[kl] * m_ls_.zvo_i[kl]) +
sF_o * (ruo_o * rvo_o + zuo_o * zvo_o) +
sqrtSH * ((m_ls_.rue_i[kl] * m_ls_.rvo_i[kl] +
m_ls_.zue_i[kl] * m_ls_.zvo_i[kl]) +
(rue_o * rvo_o + zue_o * zvo_o) +
(m_ls_.rve_i[kl] * m_ls_.ruo_i[kl] +
m_ls_.zve_i[kl] * m_ls_.zuo_i[kl]) +
(rve_o * ruo_o + zve_o * zuo_o)));

// compute remaining contribution for 3D to g_{\zeta,\zeta}
gvv[iHalf] += 0.5 * ((m_ls_.rve_i[kl] * m_ls_.rve_i[kl] +
m_ls_.zve_i[kl] * m_ls_.zve_i[kl]) +
(rve_o * rve_o + zve_o * zve_o) +
sF_i * (m_ls_.rvo_i[kl] * m_ls_.rvo_i[kl] +
m_ls_.zvo_i[kl] * m_ls_.zvo_i[kl]) +
sF_o * (rvo_o * rvo_o + zvo_o * zvo_o)) +
sqrtSH * ((m_ls_.rve_i[kl] * m_ls_.rvo_i[kl] +
m_ls_.zve_i[kl] * m_ls_.zvo_i[kl]) +
(rve_o * rvo_o + zve_o * zvo_o));

// hand over to next iteration of radial loop
// --> what was outside in this loop iteration will be inside for next
// half-grid location
m_ls_.rve_i[kl] = rve_o;
m_ls_.rvo_i[kl] = rvo_o;
m_ls_.zve_i[kl] = zve_o;
m_ls_.zvo_i[kl] = zvo_o;
}

// hand over to next iteration of radial loop
// --> what was outside in this loop iteration will be inside for next
// half-grid location
m_ls_.r1e_i[kl] = r1e_o;
m_ls_.r1o_i[kl] = r1o_o;
m_ls_.rue_i[kl] = rue_o;
m_ls_.ruo_i[kl] = ruo_o;
m_ls_.zue_i[kl] = zue_o;
m_ls_.zuo_i[kl] = zuo_o;
} // kl

// hand over to next iteration of radial loop
// --> what was outside in this loop iteration will be inside for next
// half-grid location
sF_i = sF_o;
} // jH
// gsqrt = tau * r12, and the metric elements guu, guv, gvv. Arithmetic in the
// shared, allocation-free kernel (metric_kernel.h), used by both the solver
// and the Enzyme autodiff path.
ComputeMetricElements(r1_e.data(), r1_o.data(), ru_e.data(), ru_o.data(),
zu_e.data(), zu_o.data(), rv_e.data(), rv_o.data(),
zv_e.data(), zv_o.data(), tau.data(), r12.data(),
m_p_.sqrtSF.data(), m_p_.sqrtSH.data(), s_.lthreed,
s_.nZnT, r_.nsMinF1, r_.nsMinH, r_.nsMaxH, gsqrt.data(),
guu.data(), guv.data(), gvv.data());
}

/**
Expand Down
84 changes: 84 additions & 0 deletions src/vmecpp/cpp/vmecpp/vmec/ideal_mhd_model/metric_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// SPDX-FileCopyrightText: 2024-present Proxima Fusion GmbH
// <info@proximafusion.com>
//
// SPDX-License-Identifier: MIT
#ifndef VMECPP_VMEC_IDEAL_MHD_MODEL_METRIC_KERNEL_H_
#define VMECPP_VMEC_IDEAL_MHD_MODEL_METRIC_KERNEL_H_

namespace vmecpp {

// Half-grid metric kernel, shared between IdealMhdModel::computeMetricElements
// and the Enzyme autodiff path. Works purely on caller-provided flat buffers
// and performs no allocation.
//
// For every half-grid surface jH in [nsMinH, nsMaxH) and every surface point
// kl in [0, nZnT) it writes
//   gsqrt = tau * r12
//   guu, gvv            (always)
//   guv, 3D part of gvv (only when lthreed)
// combining the two full-grid surfaces jH ("in") and jH + 1 ("out").
//
// Indexing:
//   full-grid inputs (r1e ... zvo): (jF - nsMinF1) * nZnT + kl
//   half-grid arrays (tau, r12, gsqrt, guu, guv, gvv):
//                                   (jH - nsMinH) * nZnT + kl
//   sqrtSF: jF - nsMinF1,  sqrtSH: jH - nsMinH
//
// When lthreed is false, rve/rvo/zve/zvo are never read and guv is never
// written.
inline void ComputeMetricElements(
    const double* __restrict r1e, const double* __restrict r1o,
    const double* __restrict rue, const double* __restrict ruo,
    const double* __restrict zue, const double* __restrict zuo,
    const double* __restrict rve, const double* __restrict rvo,
    const double* __restrict zve, const double* __restrict zvo,
    const double* __restrict tau, const double* __restrict r12,
    const double* __restrict sqrtSF, const double* __restrict sqrtSH,
    bool lthreed, int nZnT, int nsMinF1, int nsMinH, int nsMaxH,
    double* __restrict gsqrt, double* __restrict guu, double* __restrict guv,
    double* __restrict gvv) {
  for (int jH = nsMinH; jH < nsMaxH; ++jH) {
    // Local surface numbers of the bracketing full-grid surfaces and of the
    // half-grid surface itself.
    const int surf_in = jH - nsMinF1;
    const int surf_out = jH + 1 - nsMinF1;
    const int surf_half = jH - nsMinH;

    // s on the two full-grid surfaces, sqrt(s) on the half-grid surface.
    const double s_in = sqrtSF[surf_in] * sqrtSF[surf_in];
    const double s_out = sqrtSF[surf_out] * sqrtSF[surf_out];
    const double sqrt_s_half = sqrtSH[surf_half];

    // Flat offsets of the first surface point; hoisted out of the kl loop.
    const int base_in = surf_in * nZnT;
    const int base_out = surf_out * nZnT;
    const int base_half = surf_half * nZnT;

    for (int kl = 0; kl < nZnT; ++kl) {
      const int p_in = base_in + kl;
      const int p_out = base_out + kl;
      const int p_half = base_half + kl;

      const double r1e_in = r1e[p_in];
      const double r1e_out = r1e[p_out];
      const double r1o_in = r1o[p_in];
      const double r1o_out = r1o[p_out];
      const double rue_in = rue[p_in];
      const double rue_out = rue[p_out];
      const double ruo_in = ruo[p_in];
      const double ruo_out = ruo[p_out];
      const double zue_in = zue[p_in];
      const double zue_out = zue[p_out];
      const double zuo_in = zuo[p_in];
      const double zuo_out = zuo[p_out];

      gsqrt[p_half] = tau[p_half] * r12[p_half];

      // g_{theta,theta}: needed in both 2D and 3D.
      guu[p_half] = 0.5 * ((rue_in * rue_in + zue_in * zue_in) +
                           (rue_out * rue_out + zue_out * zue_out) +
                           s_in * (ruo_in * ruo_in + zuo_in * zuo_in) +
                           s_out * (ruo_out * ruo_out + zuo_out * zuo_out)) +
                    sqrt_s_half * ((rue_in * ruo_in + zue_in * zuo_in) +
                                   (rue_out * ruo_out + zue_out * zuo_out));

      // g_{zeta,zeta}: the R^2 part, which is all there is in 2D.
      gvv[p_half] =
          0.5 * (r1e_in * r1e_in + r1e_out * r1e_out + s_in * r1o_in * r1o_in +
                 s_out * r1o_out * r1o_out) +
          sqrt_s_half * (r1e_in * r1o_in + r1e_out * r1o_out);

      if (!lthreed) {
        // 2D: no zeta derivatives, so guv stays untouched and gvv is done.
        continue;
      }

      const double rve_in = rve[p_in];
      const double rve_out = rve[p_out];
      const double rvo_in = rvo[p_in];
      const double rvo_out = rvo[p_out];
      const double zve_in = zve[p_in];
      const double zve_out = zve[p_out];
      const double zvo_in = zvo[p_in];
      const double zvo_out = zvo[p_out];

      // g_{theta,zeta}: 3D only.
      guv[p_half] =
          0.5 * ((rue_in * rve_in + zue_in * zve_in) +
                 (rue_out * rve_out + zue_out * zve_out) +
                 s_in * (ruo_in * rvo_in + zuo_in * zvo_in) +
                 s_out * (ruo_out * rvo_out + zuo_out * zvo_out) +
                 sqrt_s_half * ((rue_in * rvo_in + zue_in * zvo_in) +
                                (rue_out * rvo_out + zue_out * zvo_out) +
                                (rve_in * ruo_in + zve_in * zuo_in) +
                                (rve_out * ruo_out + zve_out * zuo_out)));

      // Remaining 3D contribution to g_{zeta,zeta}.
      gvv[p_half] += 0.5 * ((rve_in * rve_in + zve_in * zve_in) +
                            (rve_out * rve_out + zve_out * zve_out) +
                            s_in * (rvo_in * rvo_in + zvo_in * zvo_in) +
                            s_out * (rvo_out * rvo_out + zvo_out * zvo_out)) +
                     sqrt_s_half * ((rve_in * rvo_in + zve_in * zvo_in) +
                                    (rve_out * rvo_out + zve_out * zvo_out));
    }
  }
}

} // namespace vmecpp

#endif // VMECPP_VMEC_IDEAL_MHD_MODEL_METRIC_KERNEL_H_
Loading