diff --git a/src/vmecpp/cpp/vmecpp/vmec/ideal_mhd_model/ideal_mhd_model.cc b/src/vmecpp/cpp/vmecpp/vmec/ideal_mhd_model/ideal_mhd_model.cc index 0644fb482..31535b47e 100644 --- a/src/vmecpp/cpp/vmecpp/vmec/ideal_mhd_model/ideal_mhd_model.cc +++ b/src/vmecpp/cpp/vmecpp/vmec/ideal_mhd_model/ideal_mhd_model.cc @@ -21,6 +21,7 @@ #include "vmecpp/vmec/fourier_geometry/fourier_geometry.h" #include "vmecpp/vmec/handover_storage/handover_storage.h" #include "vmecpp/vmec/ideal_mhd_model/jacobian_kernel.h" +#include "vmecpp/vmec/ideal_mhd_model/metric_kernel.h" #include "vmecpp/vmec/radial_partitioning/radial_partitioning.h" #include "vmecpp/vmec/radial_profiles/radial_profiles.h" #include "vmecpp/vmec/vmec_constants/vmec_algorithm_constants.h" @@ -1238,131 +1239,15 @@ void IdealMhdModel::computeJacobian() { } void IdealMhdModel::computeMetricElements() { - // gsqrt - // guu, guv, gvv - - // contributions from full-grid surface _i_nside j-th half-grid surface - int j0 = r_.nsMinF1; - for (int kl = 0; kl < s_.nZnT; ++kl) { - m_ls_.r1e_i[kl] = r1_e[(j0 - r_.nsMinF1) * s_.nZnT + kl]; - m_ls_.r1o_i[kl] = r1_o[(j0 - r_.nsMinF1) * s_.nZnT + kl]; - m_ls_.z1e_i[kl] = z1_e[(j0 - r_.nsMinF1) * s_.nZnT + kl]; - m_ls_.z1o_i[kl] = z1_o[(j0 - r_.nsMinF1) * s_.nZnT + kl]; - m_ls_.rue_i[kl] = ru_e[(j0 - r_.nsMinF1) * s_.nZnT + kl]; - m_ls_.ruo_i[kl] = ru_o[(j0 - r_.nsMinF1) * s_.nZnT + kl]; - m_ls_.zue_i[kl] = zu_e[(j0 - r_.nsMinF1) * s_.nZnT + kl]; - m_ls_.zuo_i[kl] = zu_o[(j0 - r_.nsMinF1) * s_.nZnT + kl]; - if (s_.lthreed) { - m_ls_.rve_i[kl] = rv_e[(j0 - r_.nsMinF1) * s_.nZnT + kl]; - m_ls_.rvo_i[kl] = rv_o[(j0 - r_.nsMinF1) * s_.nZnT + kl]; - m_ls_.zve_i[kl] = zv_e[(j0 - r_.nsMinF1) * s_.nZnT + kl]; - m_ls_.zvo_i[kl] = zv_o[(j0 - r_.nsMinF1) * s_.nZnT + kl]; - } - } - - // s on inner full-grid pos - double sF_i = - m_p_.sqrtSF[r_.nsMinH - r_.nsMinF1] * m_p_.sqrtSF[r_.nsMinH - r_.nsMinF1]; - - for (int jH = r_.nsMinH; jH < r_.nsMaxH; ++jH) { - // s on outside full-grid pos - double sF_o = - m_p_.sqrtSF[jH + 1 - r_.nsMinF1] * m_p_.sqrtSF[jH + 1 - r_.nsMinF1]; - - // sqrt(s) on j-th half-grid pos - double sqrtSH = m_p_.sqrtSH[jH - r_.nsMinH]; - - for (int kl = 0; kl < s_.nZnT; ++kl) { - int iHalf = (jH - r_.nsMinH) * s_.nZnT + kl; - - // Re-use this loop to compute Jacobian gsqrt=tau*R - // only tau needed to be checked for a sign change, - // so skip the last part where gsqrt is computed - // if a sign changed happened by computing it only here - // (which will only be reached when tau did not change sign). - gsqrt[iHalf] = tau[iHalf] * r12[iHalf]; - - // contributions from full-grid surface _o_utside j-th half-grid surface - double r1e_o = r1_e[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl]; - double r1o_o = r1_o[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl]; - double rue_o = ru_e[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl]; - double ruo_o = ru_o[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl]; - double zue_o = zu_e[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl]; - double zuo_o = zu_o[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl]; - - // g_{\theta,\theta} is needed for both 2D and 3D cases - guu[iHalf] = 0.5 * ((m_ls_.rue_i[kl] * m_ls_.rue_i[kl] + - m_ls_.zue_i[kl] * m_ls_.zue_i[kl]) + - (rue_o * rue_o + zue_o * zue_o) + - sF_i * (m_ls_.ruo_i[kl] * m_ls_.ruo_i[kl] + - m_ls_.zuo_i[kl] * m_ls_.zuo_i[kl]) + - sF_o * (ruo_o * ruo_o + zuo_o * zuo_o)) + - sqrtSH * ((m_ls_.rue_i[kl] * m_ls_.ruo_i[kl] + - m_ls_.zue_i[kl] * m_ls_.zuo_i[kl]) + - (rue_o * ruo_o + zue_o * zuo_o)); - - // g_{\zeta,\zeta} reduces to R^2 in the 2D case, so compute this always - gvv[iHalf] = 0.5 * (m_ls_.r1e_i[kl] * m_ls_.r1e_i[kl] + r1e_o * r1e_o + - sF_i * m_ls_.r1o_i[kl] * m_ls_.r1o_i[kl] + - sF_o * r1o_o * r1o_o) + - sqrtSH * (m_ls_.r1e_i[kl] * m_ls_.r1o_i[kl] + r1e_o * r1o_o); - - if (s_.lthreed) { - double rve_o = rv_e[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl]; - double rvo_o = rv_o[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl]; - double zve_o = zv_e[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl]; - double zvo_o = zv_o[(jH + 1 - r_.nsMinF1) * s_.nZnT + kl]; - - // g_{\theta,\zeta} is only needed for the 3D case - guv[iHalf] = 0.5 * ((m_ls_.rue_i[kl] * m_ls_.rve_i[kl] + - m_ls_.zue_i[kl] * m_ls_.zve_i[kl]) + - (rue_o * rve_o + zue_o * zve_o) + - sF_i * (m_ls_.ruo_i[kl] * m_ls_.rvo_i[kl] + - m_ls_.zuo_i[kl] * m_ls_.zvo_i[kl]) + - sF_o * (ruo_o * rvo_o + zuo_o * zvo_o) + - sqrtSH * ((m_ls_.rue_i[kl] * m_ls_.rvo_i[kl] + - m_ls_.zue_i[kl] * m_ls_.zvo_i[kl]) + - (rue_o * rvo_o + zue_o * zvo_o) + - (m_ls_.rve_i[kl] * m_ls_.ruo_i[kl] + - m_ls_.zve_i[kl] * m_ls_.zuo_i[kl]) + - (rve_o * ruo_o + zve_o * zuo_o))); - - // compute remaining contribution for 3D to g_{\zeta,\zeta} - gvv[iHalf] += 0.5 * ((m_ls_.rve_i[kl] * m_ls_.rve_i[kl] + - m_ls_.zve_i[kl] * m_ls_.zve_i[kl]) + - (rve_o * rve_o + zve_o * zve_o) + - sF_i * (m_ls_.rvo_i[kl] * m_ls_.rvo_i[kl] + - m_ls_.zvo_i[kl] * m_ls_.zvo_i[kl]) + - sF_o * (rvo_o * rvo_o + zvo_o * zvo_o)) + - sqrtSH * ((m_ls_.rve_i[kl] * m_ls_.rvo_i[kl] + - m_ls_.zve_i[kl] * m_ls_.zvo_i[kl]) + - (rve_o * rvo_o + zve_o * zvo_o)); - - // hand over to next iteration of radial loop - // --> what was outside in this loop iteration will be inside for next - // half-grid location - m_ls_.rve_i[kl] = rve_o; - m_ls_.rvo_i[kl] = rvo_o; - m_ls_.zve_i[kl] = zve_o; - m_ls_.zvo_i[kl] = zvo_o; - } - - // hand over to next iteration of radial loop - // --> what was outside in this loop iteration will be inside for next - // half-grid location - m_ls_.r1e_i[kl] = r1e_o; - m_ls_.r1o_i[kl] = r1o_o; - m_ls_.rue_i[kl] = rue_o; - m_ls_.ruo_i[kl] = ruo_o; - m_ls_.zue_i[kl] = zue_o; - m_ls_.zuo_i[kl] = zuo_o; - } // kl - - // hand over to next iteration of radial loop - // --> what was outside in this loop iteration will be inside for next - // half-grid location - sF_i = sF_o; - } // jH + // gsqrt = tau * r12, and the metric elements guu, guv, gvv. Arithmetic in the + // shared, allocation-free kernel (metric_kernel.h), used by both the solver + // and the Enzyme autodiff path. + ComputeMetricElements(r1_e.data(), r1_o.data(), ru_e.data(), ru_o.data(), + zu_e.data(), zu_o.data(), rv_e.data(), rv_o.data(), + zv_e.data(), zv_o.data(), tau.data(), r12.data(), + m_p_.sqrtSF.data(), m_p_.sqrtSH.data(), s_.lthreed, + s_.nZnT, r_.nsMinF1, r_.nsMinH, r_.nsMaxH, gsqrt.data(), + guu.data(), guv.data(), gvv.data()); } /** diff --git a/src/vmecpp/cpp/vmecpp/vmec/ideal_mhd_model/metric_kernel.h b/src/vmecpp/cpp/vmecpp/vmec/ideal_mhd_model/metric_kernel.h new file mode 100644 index 000000000..d014ec606 --- /dev/null +++ b/src/vmecpp/cpp/vmecpp/vmec/ideal_mhd_model/metric_kernel.h @@ -0,0 +1,84 @@ +// SPDX-FileCopyrightText: 2024-present Proxima Fusion GmbH +// +// +// SPDX-License-Identifier: MIT +#ifndef VMECPP_VMEC_IDEAL_MHD_MODEL_METRIC_KERNEL_H_ +#define VMECPP_VMEC_IDEAL_MHD_MODEL_METRIC_KERNEL_H_ + +namespace vmecpp { + +// Half-grid metric kernel: gsqrt = tau * r12 and the metric elements guu, guv, +// gvv from the full-grid geometry (and the Jacobian tau, r12 from +// ComputeHalfGridJacobian). guv and the 3D part of gvv are computed only when +// lthreed. Shared, allocation-free over flat buffers, between +// IdealMhdModel::computeMetricElements and the Enzyme autodiff path. Same +// indexing conventions as jacobian_kernel.h. sqrtSF is indexed jF - nsMinF1. +inline void ComputeMetricElements( + const double* __restrict r1e, const double* __restrict r1o, + const double* __restrict rue, const double* __restrict ruo, + const double* __restrict zue, const double* __restrict zuo, + const double* __restrict rve, const double* __restrict rvo, + const double* __restrict zve, const double* __restrict zvo, + const double* __restrict tau, const double* __restrict r12, + const double* __restrict sqrtSF, const double* __restrict sqrtSH, + bool lthreed, int nZnT, int nsMinF1, int nsMinH, int nsMaxH, + double* __restrict gsqrt, double* __restrict guu, double* __restrict guv, + double* __restrict gvv) { + for (int jH = nsMinH; jH < nsMaxH; ++jH) { + const double sF_i = sqrtSF[jH - nsMinF1] * sqrtSF[jH - nsMinF1]; + const double sF_o = sqrtSF[jH + 1 - nsMinF1] * sqrtSF[jH + 1 - nsMinF1]; + const double sH = sqrtSH[jH - nsMinH]; + for (int kl = 0; kl < nZnT; ++kl) { + const int i_in = (jH - nsMinF1) * nZnT + kl; + const int i_out = (jH + 1 - nsMinF1) * nZnT + kl; + const int ih = (jH - nsMinH) * nZnT + kl; + + const double r1e_i = r1e[i_in], r1e_o = r1e[i_out]; + const double r1o_i = r1o[i_in], r1o_o = r1o[i_out]; + const double rue_i = rue[i_in], rue_o = rue[i_out]; + const double ruo_i = ruo[i_in], ruo_o = ruo[i_out]; + const double zue_i = zue[i_in], zue_o = zue[i_out]; + const double zuo_i = zuo[i_in], zuo_o = zuo[i_out]; + + gsqrt[ih] = tau[ih] * r12[ih]; + + guu[ih] = 0.5 * ((rue_i * rue_i + zue_i * zue_i) + + (rue_o * rue_o + zue_o * zue_o) + + sF_i * (ruo_i * ruo_i + zuo_i * zuo_i) + + sF_o * (ruo_o * ruo_o + zuo_o * zuo_o)) + + sH * ((rue_i * ruo_i + zue_i * zuo_i) + + (rue_o * ruo_o + zue_o * zuo_o)); + + gvv[ih] = 0.5 * (r1e_i * r1e_i + r1e_o * r1e_o + sF_i * r1o_i * r1o_i + + sF_o * r1o_o * r1o_o) + + sH * (r1e_i * r1o_i + r1e_o * r1o_o); + + if (lthreed) { + const double rve_i = rve[i_in], rve_o = rve[i_out]; + const double rvo_i = rvo[i_in], rvo_o = rvo[i_out]; + const double zve_i = zve[i_in], zve_o = zve[i_out]; + const double zvo_i = zvo[i_in], zvo_o = zvo[i_out]; + + guv[ih] = 0.5 * ((rue_i * rve_i + zue_i * zve_i) + + (rue_o * rve_o + zue_o * zve_o) + + sF_i * (ruo_i * rvo_i + zuo_i * zvo_i) + + sF_o * (ruo_o * rvo_o + zuo_o * zvo_o) + + sH * ((rue_i * rvo_i + zue_i * zvo_i) + + (rue_o * rvo_o + zue_o * zvo_o) + + (rve_i * ruo_i + zve_i * zuo_i) + + (rve_o * ruo_o + zve_o * zuo_o))); + + gvv[ih] += 0.5 * ((rve_i * rve_i + zve_i * zve_i) + + (rve_o * rve_o + zve_o * zve_o) + + sF_i * (rvo_i * rvo_i + zvo_i * zvo_i) + + sF_o * (rvo_o * rvo_o + zvo_o * zvo_o)) + + sH * ((rve_i * rvo_i + zve_i * zvo_i) + + (rve_o * rvo_o + zve_o * zvo_o)); + } + } + } +} + +} // namespace vmecpp + +#endif // VMECPP_VMEC_IDEAL_MHD_MODEL_METRIC_KERNEL_H_