Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
434 commits
Select commit Hold shift + click to select a range
59cca07
Merge branch 'ku_rc/anistropy' into dp/laplace
unalmis Feb 20, 2025
e2f7ed1
move docstring to correct test
unalmis Feb 20, 2025
4e77f0f
cosmetic change to avoid confusion
unalmis Feb 20, 2025
ffcc6a1
better comment in test
unalmis Feb 20, 2025
faa14e6
Remove image compare for now
unalmis Feb 20, 2025
9d09393
Halfway through debugging
unalmis Feb 21, 2025
7a1b30d
Add nfp loop wrapper
unalmis Feb 22, 2025
7ec2f7c
Fix laplace solver
unalmis Feb 22, 2025
e2a9b21
Add markers for tests
unalmis Feb 22, 2025
1a73002
Merge branch 'master' into ku/biot
unalmis Feb 22, 2025
22f1dcd
Merge branch 'ku/biot' into dp/laplace
unalmis Feb 22, 2025
c7b996d
Finalize API for vacuum solver
unalmis Feb 23, 2025
ae53815
Allow solving overdetermined system and passing remaining tests
unalmis Feb 23, 2025
50b2ac3
Add warnings about JAX bug
unalmis Feb 23, 2025
a1d4bf0
grad(Phi) not grad(phi)
unalmis Feb 23, 2025
4d73f92
Fix bug for setting default variable
unalmis Feb 23, 2025
869c81b
Demonstrate #1599
unalmis Feb 23, 2025
5ec2e67
Merge branch 'master' into ku/biot
dpanici Feb 24, 2025
767828f
Merge branch 'master' into ku/biot
dpanici Feb 26, 2025
db8962b
modify docstrings for fxns that dont use chunk_size
dpanici Feb 26, 2025
2d4b1ea
Merge branch 'ku/biot' into dp/laplace
dpanici Feb 26, 2025
3cf860e
Merge branch 'master' into dp/laplace
unalmis Mar 1, 2025
c454b23
Rename laplace to vacuum
unalmis Mar 1, 2025
bce7e69
Use master compute data
unalmis Mar 1, 2025
b298b3f
Revert nfp loop
unalmis Mar 1, 2025
6ad5988
Remove thing moved to other pr
unalmis Mar 1, 2025
c7687e0
Use e^rho*sqrt(g) instead of n_rho
unalmis Mar 1, 2025
4102724
Update changelog
unalmis Mar 1, 2025
7b2fdd9
Fix mistake with sfaenormalize
unalmis Mar 1, 2025
5acf17a
Fix local_params func
unalmis Mar 3, 2025
22bcfa9
same as previous commit
unalmis Mar 3, 2025
7436d00
make sure int are used so that jax compiler works better
unalmis Mar 3, 2025
a8b2da9
Use get_interpolatar func in free bdry
unalmis Mar 3, 2025
9a61e25
Merge branch 'master' into dp/laplace
unalmis Mar 3, 2025
f48ce7a
Merge branch 'master' into dp/laplace
unalmis Mar 9, 2025
77b88e4
Merge branch 'master' into dp/laplace
unalmis Mar 10, 2025
52f789d
Make sure errors and warnings print in expected color
unalmis Mar 10, 2025
8625275
Add mode_idx option ransform
unalmis Mar 10, 2025
6aedc7d
Make inversion more efficient
unalmis Mar 10, 2025
515e0fd
Make inversion more efficient (#1637)
unalmis Mar 11, 2025
4f7c97d
revert changes to transform
unalmis Mar 11, 2025
1c8161b
Add fixed point solve for Phi
unalmis Mar 12, 2025
0a74536
Merge branch 'master' into dp/laplace
unalmis Mar 12, 2025
076d31a
Merge branch 'dp/laplace' into ku/fp
unalmis Mar 12, 2025
0def4de
Merge branch 'ku/cholesky' into dp/laplace
unalmis Mar 12, 2025
a74af4a
progressing
unalmis Mar 12, 2025
300a5f3
Add edited version of interp_2d from interpax
unalmis Mar 12, 2025
cc0ea98
Add fixed point routine
unalmis Mar 13, 2025
fd09d6a
Merge branch 'master' into dp/laplace
unalmis Mar 14, 2025
11ddd80
Merge branch 'master' into dp/laplace
YigitElma Mar 16, 2025
ebf642f
Merge branch 'master' into dp/laplace
unalmis Mar 20, 2025
4a6767c
debugging fixed point
unalmis Mar 20, 2025
0351391
Better initial guess
unalmis Mar 21, 2025
c6dc729
Speed of matmul
unalmis Mar 22, 2025
e2a308f
j
unalmis Mar 22, 2025
188b6b3
Merge branch 'master' into dp/laplace
unalmis Apr 5, 2025
e33ec14
Fix grammer
unalmis Apr 5, 2025
3a9f3a9
Second pass at new free boundary; taking break for the night
unalmis Apr 9, 2025
57fa57a
Fix import error
unalmis Apr 9, 2025
38b8acf
Simplified Bout computation in free boundary (#1600)
unalmis Apr 10, 2025
606ccfe
Add ability to evaluate at R phi Z in interior
unalmis Apr 12, 2025
c173e54
Merge branch 'master' into dp/laplace
unalmis Apr 12, 2025
3990080
off surface eval
unalmis Apr 12, 2025
6645de0
Merge branch 'master' into dp/laplace
unalmis Apr 17, 2025
62b5e98
Merge branch 'ku/pos' into dp/laplace
unalmis Apr 17, 2025
9866167
Merge branch 'master' into dp/laplace
unalmis Apr 17, 2025
25a2fa6
Change iterations to ones that should in theory converge
unalmis Apr 17, 2025
6fffa46
Add exterior solver
unalmis Apr 18, 2025
63aab5b
Merge branch 'master' into dp/laplace
unalmis Apr 18, 2025
19ed1d4
add missing key
dpanici Apr 18, 2025
ca0b397
Update alternative free bdry solver for left inverse surface gradient
unalmis Apr 19, 2025
1299544
speed up inversion
unalmis Apr 19, 2025
4702be8
Bypass large tensor product
unalmis Apr 19, 2025
6c53b32
Handle secular terms of B_coil properly
unalmis Apr 19, 2025
00f83e6
Add axis limit
unalmis Apr 19, 2025
3990871
Update docstrings to warn about jax bug
unalmis May 25, 2025
13a9f56
Merge branch 'master' into dp/laplace
unalmis May 26, 2025
d75b2ad
Merge branch 'master' into dp/laplace
unalmis Jun 1, 2025
edb415b
Merge branch 'master' into dp/laplace
unalmis Jun 5, 2025
a19abff
Pushing working commit
unalmis Jun 6, 2025
4b2b1f5
push updated test to sync with partner
unalmis Jun 6, 2025
321b002
fix sym bug
dpanici Jun 6, 2025
a3c25ab
Update basis xyz
unalmis Jun 8, 2025
bf8c24f
Merge branch 'dp/laplace' of github.com:PlasmaControl/DESC into dp/la…
dpanici Jun 9, 2025
fe5d5ff
change resolution
unalmis Jun 10, 2025
8f7a87c
Update README.rst
unalmis Jun 10, 2025
50cee08
Merge branch 'master' into dp/laplace
unalmis Jun 10, 2025
bea5cb6
change exterior test so that singularity is outside domain and thus t…
dpanici Jun 10, 2025
a08f55d
Merge branch 'dp/laplace' of github.com:PlasmaControl/DESC into dp/la…
dpanici Jun 10, 2025
9963681
Add test for splitting potential into secular and periodic parts
unalmis Jun 10, 2025
bd048f5
Debugging test
unalmis Jun 10, 2025
3124441
Merge branch 'master' into dp/laplace
unalmis Jun 12, 2025
9613f16
Merge branch 'master' into dp/laplace
unalmis Jun 19, 2025
e1595db
update
unalmis Jun 20, 2025
953d371
working commit for free bdry
dpanici Jun 26, 2025
7d18caa
Merge branch 'master' into dp/laplace
unalmis Jun 29, 2025
b626fb1
no change
unalmis Jun 29, 2025
beaef69
Updating naming for consistency with paper
unalmis Jun 30, 2025
6934044
Decreaes chunk size
unalmis Jun 30, 2025
f86c264
fix failing test
unalmis Jun 30, 2025
59cdb91
Merge branch 'master' into dp/laplace
unalmis Jul 1, 2025
1d6c1ca
Updating sign conventions to match paper
unalmis Jul 3, 2025
0bed5a5
fix comment
unalmis Jul 3, 2025
060bf02
Remove singularity
unalmis Jul 6, 2025
36f5a13
Merge branch 'master' into dp/laplace
unalmis Jul 10, 2025
9e08f60
non-singular Laplace solver and new magnetic field API (#1805)
unalmis Jul 11, 2025
9984de1
Fix circular import
unalmis Jul 11, 2025
e0709be
Fixing more imports
unalmis Jul 11, 2025
c4fe291
Remove out of date comment
unalmis Jul 11, 2025
3973dc9
Fix jit stuff from previous commit
unalmis Jul 11, 2025
baaf1fb
Merge branch 'master' into dp/laplace
unalmis Jul 11, 2025
9754681
resolving merge conflicts
unalmis Jul 11, 2025
d6cde7d
Add options for convergence plots
unalmis Jul 12, 2025
3d5fde6
Merge branch 'master' into dp/laplace
unalmis Jul 12, 2025
58d1e65
Add marker to skip plot test
unalmis Jul 12, 2025
141a71e
Add another test
unalmis Jul 12, 2025
9c5cac7
Update
unalmis Jul 18, 2025
19b2a4b
Merge branch 'master' into dp/laplace
unalmis Jul 18, 2025
9c30209
Remove old commetn
unalmis Jul 18, 2025
57c3eff
Fix failing test
unalmis Jul 19, 2025
80a3d7d
Remove old warnings
unalmis Jul 19, 2025
79ccfc9
jit fixed point method
unalmis Jul 20, 2025
e64e649
Add plotting utils for convergence
unalmis Jul 21, 2025
d3e2b28
fix comment
unalmis Jul 21, 2025
5c53628
increase font size
unalmis Jul 21, 2025
0350064
Adds some compute quantities
unalmis Jul 22, 2025
599691b
Merge branch 'master' into dp/laplace
unalmis Jul 22, 2025
7b08b51
update compute data for new quantities
unalmis Jul 22, 2025
9ede6f9
use master compute data file
unalmis Jul 22, 2025
192caa7
free surf parallel (#1822)
unalmis Jul 23, 2025
52ad554
Apply suggestions from code review
unalmis Jul 23, 2025
a28eca6
Free surface potentials have stellarator symmetry
unalmis Jul 24, 2025
682c460
Don't use eq.surface since that is not updated by optimizer
unalmis Jul 24, 2025
b756377
remove no longer needed code
unalmis Jul 24, 2025
e493493
Add coil grid as kwarg
unalmis Jul 27, 2025
3920fd3
Build psinv in get_transforms
unalmis Jul 27, 2025
afd1d84
Revert "working commit for free bdry"
dpanici Jul 28, 2025
418393c
Merge branch 'dp/laplace' of github.com:PlasmaControl/DESC into dp/la…
dpanici Jul 28, 2025
75e57c5
Add B_coil_chunk_size and don't interpolate gradients
unalmis Jul 28, 2025
e8631ba
Merge branch 'dp/laplace' of github.com:PlasmaControl/DESC into dp/la…
dpanici Jul 28, 2025
eb2c516
Dp/laplace 2 (#1824)
dpanici Jul 28, 2025
ac936ef
Use 10000 instead of jnp.inf
unalmis Jul 28, 2025
d1c55b8
Merge branch 'master' into dp/laplace
unalmis Aug 1, 2025
50cd66f
Add static attributes to freesurfaceerror objective
unalmis Aug 1, 2025
fb7b65d
Remove old code
unalmis Aug 1, 2025
aa5a8ea
Reduce memory usage
unalmis Aug 1, 2025
95d4ea8
Dummy commit to restart tests
unalmis Aug 1, 2025
a0f72c0
Reduce chunk size in test
unalmis Aug 1, 2025
e5d990b
switch to lineax
unalmis Aug 1, 2025
d5b0ad3
Clean up if logic
unalmis Aug 1, 2025
b0256dc
Set well_posed to False for interiorNeumann because ill-conditioned
unalmis Aug 2, 2025
58b101c
Add lineax tags
unalmis Aug 2, 2025
ec8f495
Don't assume integrals were computed with sufficient accuracy
unalmis Aug 2, 2025
12b18c2
same as previous commit
unalmis Aug 2, 2025
f4272a8
set well_posed=False to avoid NaNs
dpanici Aug 2, 2025
d054f9c
Merge branch 'dp/laplace' of github.com:PlasmaControl/DESC into dp/la…
dpanici Aug 2, 2025
7a61ace
Merge branch 'master' into dp/laplace
unalmis Aug 2, 2025
15ddd99
Add Neumann free surface
unalmis Aug 2, 2025
41a61e0
fix string for old python versions
unalmis Aug 2, 2025
60c1608
Add warnings about symmetry
unalmis Aug 2, 2025
fea8cbc
Fix spelling mistakes
unalmis Aug 2, 2025
91a12ea
Remove warnings about symmetry of basis
unalmis Aug 2, 2025
0cc5ed4
Update docstring in laplace.py
unalmis Aug 2, 2025
bd8eab7
Update _laplace.py doc comment
unalmis Aug 2, 2025
88232cf
Update free surface docstring
unalmis Aug 2, 2025
a09e872
Update _free_boundary.py
unalmis Aug 2, 2025
949018e
Update _laplace.py kwarg docstring
unalmis Aug 2, 2025
bccc266
Update _free_boundary.py docstring
unalmis Aug 2, 2025
97d72c6
Update comment in docstring
unalmis Aug 2, 2025
8590852
Merge branch 'master' into dp/laplace
unalmis Aug 4, 2025
d818a97
Add some comments to test_exterior_Neumann
unalmis Aug 8, 2025
2f84eeb
Merge branch 'master' into dp/laplace
unalmis Aug 8, 2025
c9663c0
Add unused dependency comment
unalmis Aug 8, 2025
ce84c49
Merge branch 'master' into dp/laplace
unalmis Aug 15, 2025
a12a69f
merging
unalmis Aug 15, 2025
c4f3b6e
Merge branch 'master' into dp/laplace
rahulgaur104 Aug 20, 2025
e6d68e1
mark function as private
unalmis Aug 23, 2025
a43da7b
Merge branch 'master' into dp/laplace
unalmis Aug 23, 2025
9416064
Pulling changes downstream from https://github.com/f0uriest/interpax/…
unalmis Aug 24, 2025
cf919bb
Increase coverage
unalmis Aug 25, 2025
f09eb15
Set to inf as that was not cause of nan
unalmis Sep 2, 2025
16a0ba5
Merge branch 'master' into dp/laplace
unalmis Sep 2, 2025
9346c8e
fix bad merge
unalmis Sep 2, 2025
8861bdf
increase codecov
unalmis Sep 3, 2025
c256dc6
Default to fixed point method and add stop tolerance option
unalmis Sep 3, 2025
bd7c1da
Merge branch 'master' into dp/laplace
unalmis Sep 4, 2025
a3d9915
Merge branch 'master' into dp/laplace
unalmis Sep 4, 2025
8a889af
Merge branch 'master' into dp/laplace
unalmis Sep 9, 2025
6853d59
Merge branch 'master' into dp/laplace
unalmis Sep 12, 2025
52eedd8
Cast theta, rho, zeta arrays to jax arrays
unalmis Sep 13, 2025
212372a
Update singularities.py
unalmis Sep 13, 2025
2419287
Remove unused kernel and unused warn_dft kwarg
unalmis Sep 14, 2025
fb2d8b7
some plumbing to enable interpolation to less dense grid
unalmis Sep 14, 2025
3995d06
fix master compue data
unalmis Sep 14, 2025
a608606
improve test
unalmis Sep 15, 2025
4ad6d33
cosmetic change
unalmis Sep 15, 2025
e286daa
same as previous commit
unalmis Sep 15, 2025
be2cbc1
cast jax array to int
unalmis Sep 15, 2025
cabe5ff
Add warning to objective to finish pull request.
unalmis Sep 15, 2025
82d2fe6
Update changelog
unalmis Sep 15, 2025
602e233
Merge branch 'master' into dp/laplace
unalmis Sep 21, 2025
9a5009e
Merge branch 'master' into dp/laplace
unalmis Sep 25, 2025
317f27c
Merge branch 'master' into dp/laplace
unalmis Oct 1, 2025
3412523
Merge branch 'master' into dp/laplace
dpanici Oct 7, 2025
cf040ae
Trial 1: Merge branch 'master' into dp/laplace
unalmis Oct 9, 2025
5a47bad
Add missing test from master
unalmis Oct 9, 2025
2907cc4
Add cosmetic space deleted during merge
unalmis Oct 9, 2025
207a485
Merge branch 'master' into dp/laplace
dpanici Oct 9, 2025
aff97e2
Merge branch 'master' into dp/laplace
unalmis Oct 13, 2025
4b86872
Merge branch 'master' into dp/laplace
unalmis Oct 13, 2025
e173c56
Merge branch 'master' into dp/laplace
dpanici Oct 15, 2025
b75accb
add anderson acceleration implementation
dpanici Oct 16, 2025
09830a0
change default params
dpanici Oct 16, 2025
047d34d
update beta to not use relaxation, as convergence seems good without it
dpanici Oct 19, 2025
14e96fe
Merge branch 'master' into dp/laplace
unalmis Nov 11, 2025
53e0c68
change default back to simple so tests pass, strange that with anders…
dpanici Nov 11, 2025
28f2f3f
use code from interpax_fft
unalmis Dec 16, 2025
4677d43
Merge branch 'master' into dp/laplace
unalmis Dec 16, 2025
5315fc8
fix requirements.txt
unalmis Dec 16, 2025
29b7568
Merge branch 'master' into dp/laplace
unalmis Feb 19, 2026
fe0473e
merge
unalmis Feb 19, 2026
def22b7
Merge branch 'master' into dp/laplace
unalmis Feb 27, 2026
dad165d
Merge branch 'master' into dp/laplace
dpanici Mar 3, 2026
29636ba
Merge branch 'master' into dp/laplace
unalmis Mar 4, 2026
02daa92
Merge branch 'master' into dp/laplace
unalmis Apr 1, 2026
b7cec8f
Merge branch 'master' into dp/laplace
unalmis Apr 2, 2026
fdf3dd4
Update requirements.txt
unalmis Apr 2, 2026
27548ed
Merge branch 'master' into dp/laplace
unalmis Apr 8, 2026
0a6e0b8
Merge branch 'master' into dp/laplace
unalmis Apr 9, 2026
1b14537
Merge branch 'master' into dp/laplace
unalmis Apr 17, 2026
ee52f72
Merge branch 'master' into dp/laplace
unalmis Apr 18, 2026
fbab6c0
Merge branch 'master' into dp/laplace
unalmis May 15, 2026
0ce9241
use lineax solver
unalmis May 15, 2026
5668ae3
.
unalmis May 15, 2026
38c1fcb
.
unalmis May 15, 2026
2e16903
add script to profile
unalmis May 16, 2026
d50f976
Resolves performance bottleneck discussed here: https://github.com/Pl…
unalmis May 16, 2026
c24200a
increase codecov
unalmis May 16, 2026
06bc963
don't close over lambda
unalmis May 16, 2026
2221f69
Merge branch 'master' into ku/laplace
unalmis May 18, 2026
d5371aa
Merge branch 'master' into ku/laplace
YigitElma May 19, 2026
444ca37
Merge branch 'master' into ku/laplace
unalmis May 19, 2026
69d556f
obstacle from other prs
unalmis May 20, 2026
45ef6cb
cleanup
unalmis May 30, 2026
c978773
.
unalmis May 30, 2026
ea107ec
.
unalmis May 30, 2026
faa4b2c
.
unalmis May 30, 2026
cf144da
Remove redundant option
unalmis May 31, 2026
b07ec20
.
unalmis Jun 13, 2026
d41a313
.
unalmis Jun 13, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ v0.17.0
-------

New Features
------------
- Automatically-differentiable, non-singular Laplace BIE solver.
- Improved performance and accuracy of FFT interpolation in singular integrals
([1](https://github.com/f0uriest/interpax/pull/116), [2](https://github.com/f0uriest/interpax/pull/117)).
This is useful for free surface optimization.
- [Plumbing for new magnetic field API](https://github.com/PlasmaControl/DESC/issues/1807).

- Adds particle tracing capabilities in ``desc.particles`` module.
- Particle tracing is done via ``desc.particles.trace_particles`` function.
Expand Down
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ Contribute
- `Contributing guidelines <https://github.com/PlasmaControl/DESC/blob/master/CONTRIBUTING.rst>`_
- `Issue Tracker <https://github.com/PlasmaControl/DESC/issues>`_
- `Source Code <https://github.com/PlasmaControl/DESC/>`_
- `Documentation <https://desc-docs.readthedocs.io/>`_
- `Documentation <https://desc-docs.readthedocs.io/en/stable/>`_

.. |License| image:: https://img.shields.io/github/license/PlasmaControl/desc?color=blue&logo=open-source-initiative&logoColor=white
:target: https://github.com/PlasmaControl/DESC/blob/master/LICENSE
Expand Down
89 changes: 50 additions & 39 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,50 @@ def print_backend_info():
)


def _is_converged(residual, tol):
return jnp.sum(residual * residual) <= tol**2


def _lstsq(A, y):
"""Cholesky factorized least-squares.

jnp.linalg.lstsq doesn't have JVP defined and is slower than needed,
so we use regularized cholesky.

For square systems, solves Ax=y directly.
"""
A = jnp.atleast_2d(A)
y = jnp.atleast_1d(y)
eps = jnp.sqrt(jnp.finfo(A.dtype).eps)
if A.shape[-2] == A.shape[-1]:
return jnp.linalg.solve(A, y) if y.size > 1 else jnp.squeeze(y / A)
elif A.shape[-2] > A.shape[-1]:
P = A.T @ A + eps * jnp.eye(A.shape[-1])
return cho_solve(cho_factor(P), A.T @ y)
else:
P = A @ A.T + eps * jnp.eye(A.shape[-2])
return A.T @ cho_solve(cho_factor(P), y)


def _tangent_solve(g, y):
# System is always square.
return _lstsq(jax.jacfwd(g)(y), y)


def _tangent_solve_scalar(g, y):
return y / g(1.0)


def _map(f, xs, *, batch_size=None, in_axes=0, out_axes=0):
"""Generalizes jax.lax.map; uses numpy."""
if not isinstance(xs, np.ndarray):
raise NotImplementedError(
"Require numpy array input, or install jax to support pytrees."
)
xs = np.moveaxis(xs, source=in_axes, destination=0)
return np.stack([f(x) for x in xs], axis=out_axes)


def _diag_to_full(d, e):
j = np.arange(d.shape[-1])
return (
Expand Down Expand Up @@ -360,7 +404,7 @@ def backtrack(xk1, fk1, d):

def condfun(state):
xk1, fk1, k1 = state
return (k1 < maxiter) & (jnp.dot(fk1, fk1) > tol**2)
return (k1 < maxiter) & (~_is_converged(fk1, tol))

def bodyfun(state):
xk1, fk1, k1 = state
Expand All @@ -376,41 +420,17 @@ def bodyfun(state):
else:
return state[0]

def tangent_solve(g, y):
return y / g(1.0)

if full_output:
x, (res, niter) = jax.lax.custom_root(
res, x0, solve, tangent_solve, has_aux=True
res, x0, solve, _tangent_solve_scalar, has_aux=True
)
return x, (abs(res), niter)
else:
x = jax.lax.custom_root(res, x0, solve, tangent_solve, has_aux=False)
x = jax.lax.custom_root(
res, x0, solve, _tangent_solve_scalar, has_aux=False
)
return x

def _lstsq(A, y):
"""Cholesky factorized least-squares.

jnp.linalg.lstsq doesn't have JVP defined and is slower than needed,
so we use regularized cholesky.

For square systems, solves Ax=y directly.
"""
A = jnp.atleast_2d(A)
y = jnp.atleast_1d(y)
eps = jnp.sqrt(jnp.finfo(A.dtype).eps)
if A.shape[-2] == A.shape[-1]:
return jnp.linalg.solve(A, y) if y.size > 1 else jnp.squeeze(y / A)
elif A.shape[-2] > A.shape[-1]:
P = A.T @ A + eps * jnp.eye(A.shape[-1])
return cho_solve(cho_factor(P), A.T @ y)
else:
P = A @ A.T + eps * jnp.eye(A.shape[-2])
return A.T @ cho_solve(cho_factor(P), y)

def _tangent_solve(g, y):
return _lstsq(jax.jacfwd(g)(y), y)

def root(
fun,
x0,
Expand Down Expand Up @@ -495,7 +515,7 @@ def backtrack(xk1, fk1, d):

def condfun(state):
xk1, fk1, k1 = state
return (k1 < maxiter) & (jnp.dot(fk1, fk1) > tol**2)
return (k1 < maxiter) & (~_is_converged(fk1, tol))

def bodyfun(state):
xk1, fk1, k1 = state
Expand Down Expand Up @@ -552,15 +572,6 @@ def bodyfun(state):
excluded={"eigvals_only", "select", "select_range", "tol"},
)

def _map(f, xs, *, batch_size=None, in_axes=0, out_axes=0):
"""Generalizes jax.lax.map; uses numpy."""
if not isinstance(xs, np.ndarray):
raise NotImplementedError(
"Require numpy array input, or install jax to support pytrees."
)
xs = np.moveaxis(xs, source=in_axes, destination=0)
return np.stack([f(x) for x in xs], axis=out_axes)

def vmap(fun, in_axes=0, out_axes=0):
"""A numpy implementation of jax.lax.map whose API is a subset of jax.vmap.

Expand Down
37 changes: 22 additions & 15 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None):
else:
lidx = loutidx = np.arange(len(modes))
if (derivatives[1] != 0) or (derivatives[2] != 0):
return jnp.zeros((grid.num_nodes, modes.shape[0]))
return jnp.zeros((grid.num_nodes, self.num_modes))
if not len(modes):
return np.array([]).reshape((grid.num_nodes, 0))

Expand Down Expand Up @@ -537,7 +537,7 @@ def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None):
else:
nidx = noutidx = np.arange(len(modes))
if (derivatives[0] != 0) or (derivatives[1] != 0):
return jnp.zeros((grid.num_nodes, modes.shape[0]))
return jnp.zeros((grid.num_nodes, self.num_modes))
if not len(modes):
return np.array([]).reshape((grid.num_nodes, 0))

Expand Down Expand Up @@ -635,7 +635,13 @@ def _get_modes(self, M, N):
z = np.zeros_like(m)
return np.array([z, m, n]).T

def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None):
def evaluate(
self,
grid,
derivatives=np.array([0, 0, 0]),
modes=None,
**kwargs,
):
"""Evaluate basis functions at specified nodes.

Parameters
Expand All @@ -644,7 +650,7 @@ def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None):
Node coordinates, in (rho,theta,zeta).
derivatives : ndarray of int, shape(num_derivatives,3)
Order of derivatives to compute in (rho,theta,zeta).
modes : ndarray of in, shape(num_modes,3), optional
modes : ndarray of int, shape(num_modes,3), optional
Basis modes to evaluate (if None, full basis is used).

Returns
Expand All @@ -670,7 +676,7 @@ def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None):
midx = moutidx = np.arange(len(modes))
nidx = noutidx = np.arange(len(modes))
if derivatives[0] != 0:
return jnp.zeros((grid.num_nodes, modes.shape[0]))
return jnp.zeros((grid.num_nodes, self.num_modes))
if not len(modes):
return np.array([]).reshape((grid.num_nodes, 0))

Expand All @@ -688,17 +694,18 @@ def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None):
_, t, z = grid.nodes.T
_, m, n = modes.T

t = t[tidx]
z = z[zidx]
t = kwargs["t"] if "t" in kwargs else t[tidx]
z = kwargs["z"] if "z" in kwargs else z[zidx]
m = m[midx]
n = n[nidx]

poloidal = fourier(t[:, np.newaxis], m, 1, derivatives[1])
toroidal = fourier(z[:, np.newaxis], n, self.NFP, derivatives[2])
poloidal = poloidal[toutidx][:, moutidx]
toroidal = toroidal[zoutidx][:, noutidx]

return poloidal * toroidal
poloidal = fourier(t[:, np.newaxis], m, 1, derivatives[1])[:, moutidx]
toroidal = fourier(z[:, np.newaxis], n, self.NFP, derivatives[2])[:, noutidx]
if grid.is_meshgrid and grid.num_rho == 1:
return (poloidal[:, np.newaxis] * toroidal[np.newaxis]).reshape(
-1, self.num_modes, order="F"
)
return poloidal[toutidx] * toroidal[zoutidx]

def change_resolution(self, M, N, NFP=None, sym=None):
"""Change resolution of the basis to the given resolutions.
Expand Down Expand Up @@ -881,7 +888,7 @@ def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None):
lmidx = lmoutidx = np.arange(len(modes))
midx = moutidx = np.arange(len(modes))
if derivatives[2] != 0:
return jnp.zeros((grid.num_nodes, modes.shape[0]))
return jnp.zeros((grid.num_nodes, self.num_modes))
if not len(modes):
return np.array([]).reshape((grid.num_nodes, 0))

Expand Down Expand Up @@ -1437,7 +1444,7 @@ def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None):
else:
lidx = loutidx = np.arange(len(modes))
if (derivatives[1] != 0) or (derivatives[2] != 0):
return jnp.zeros((grid.num_nodes, modes.shape[0]))
return jnp.zeros((grid.num_nodes, self.num_modes))
if not len(modes):
return np.array([]).reshape((grid.num_nodes, 0))

Expand Down
5 changes: 3 additions & 2 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1603,7 +1603,7 @@ def compute(

Returns
-------
data : list of dict of ndarray
data : list[dict[str, jnp.ndarray]]
Computed quantity and intermediate variables, for each coil in the set.
List entries map to coils in coilset, each dict contains data for an
individual coil.
Expand Down Expand Up @@ -1843,6 +1843,7 @@ def body(AB, x):
return AB, None

AB += scan(body, jnp.zeros(coords_nfp.shape), tree_stack(params))[0]

return AB

AB = fori_loop(0, self.NFP, nfp_loop, jnp.zeros_like(coords_rpz))
Expand Down Expand Up @@ -2753,7 +2754,7 @@ def compute(

Returns
-------
data : list of dict of ndarray
data : list[dict[str, jnp.ndarray]]
Computed quantity and intermediate variables, for each coil in the set.
List entries map to coils in coilset, each dict contains data for an
individual coil.
Expand Down
1 change: 1 addition & 0 deletions desc/compute/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
_fast_ion,
_field,
_geometry,
_laplace,
_metric,
_neoclassical,
_old,
Expand Down
Loading
Loading