
Revise JAX intro lecture and add autodiff lecture #513

Open

jstac wants to merge 4 commits into main from jax-intro-revisions

Conversation

jstac (Contributor) commented Apr 5, 2026

Summary

  • Bug fixes: Fixed coefficient mismatch (0.1 * x**2 vs x**2) in NumPy/JAX function comparison; updated deprecated jax.random.PRNGKey to jax.random.key throughout
  • New figures: Added code-generated PRNG key splitting tree diagram and JIT compilation pipeline diagram
  • New section: Added vmap section with mean/median example showing why Python loops are inefficient with JAX
  • Autodiff preview: Reworked the gradients section as a brief preview with forward reference to the new autodiff lecture
  • New lecture: Added "Adventures with Autodiff" (adapted from lecture-jax), covering finite differences vs symbolic vs autodiff, gradient descent with Barzilai-Borwein, and OLS regression
  • Housekeeping: Consolidated all imports into initial cell, added (jax_intro)= reference label, updated _toc.yml

Test plan

  • CI build passes (executes all notebooks)
  • Verify PRNG key splitting tree figure renders correctly
  • Verify JIT pipeline figure renders correctly
  • Verify autodiff lecture cross-references resolve
  • Check vmap example output is clear

🤖 Generated with Claude Code

jstac and others added 4 commits April 5, 2026 14:27
- Fix coefficient mismatch in NumPy vs JAX function comparison
- Update jax.random.PRNGKey to jax.random.key throughout
- Add code-generated figures (PRNG key splitting tree, JIT pipeline)
- Add vmap section with examples and transformation composition
- Rework gradients section as autodiff preview with forward reference
- Add autodiff lecture (adapted from lecture-jax) to TOC
- Consolidate all imports into initial cell

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Replace artificial sum-of-squares vmap example with mean/median statistics
- Add explanation of why Python loops are inefficient with JAX
- Move all imports to initial cell per lecture conventions

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Fix variable name typo (z_max_numpy → z_max_numba)
- Fix vmap v2 print label
- Fix garbled em dash
- Consolidate imports to top of lecture
- Use jnp.meshgrid instead of np.meshgrid for JAX arrays
- Replace cm.jet with cm.viridis
- Qualify JAX speed claim re GPU
- Add overall recommendations section synthesizing trade-offs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
jstac (Contributor, Author) commented Apr 5, 2026

Detailed changelog

jax_intro.md

Bug fixes:

  • Fixed coefficient mismatch in NumPy vs JAX function comparison (0.1 * x**2 vs x**2 — the two versions of f were computing different functions)
  • Updated jax.random.PRNGKey → jax.random.key (5 occurrences); PRNGKey is the legacy API
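The API migration above can be sketched as follows (illustrative, not the lecture's code): the new jax.random.key returns an opaque, scalar-shaped typed key, whereas the legacy PRNGKey returns a raw uint32 array.

```python
import jax

# Legacy API: a raw uint32 array of shape (2,)
legacy_key = jax.random.PRNGKey(0)

# New API: an opaque typed key array with scalar shape
key = jax.random.key(0)

# Both are accepted by the samplers in jax.random
x_new = jax.random.normal(key, (3,))
```

Both keys use the same underlying PRNG, so for a given seed they drive the same sampling stream.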

New content:

  • Added code-generated PRNG key splitting tree diagram illustrating how split produces independent keys
  • Added JIT compilation pipeline diagram (Python function → trace → XLA → execution)
  • Added new vmap section with a mean/median example showing why Python loops are inefficient with JAX, and how transformations compose (jit(vmap(...)))
  • Reworked "Gradients" section into "Automatic differentiation: a preview" with forward reference to the new autodiff lecture
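A minimal sketch of the kind of vmap example described above (the exact statistic and sample sizes here are assumptions, not the lecture's code): a per-sample statistic is vectorized over a batch of samples, and the transformations compose with jit.

```python
import jax
import jax.numpy as jnp

def mean_median_gap(sample):
    # Statistic of one 1-D sample: |mean - median|
    return jnp.abs(jnp.mean(sample) - jnp.median(sample))

# 1000 samples, each of size 50
samples = jax.random.normal(jax.random.key(42), (1000, 50))

# vmap maps the statistic over the leading axis without a Python loop,
# and transformations compose: we can JIT the vectorized function
gaps = jax.jit(jax.vmap(mean_median_gap))(samples)
```

A Python loop over the 1000 rows would dispatch 1000 small device computations; the vmapped version compiles to one batched kernel.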

Housekeeping:

  • Consolidated all imports into the initial cell (was scattered across 4 locations)
  • Added (jax_intro)= reference label for cross-referencing
  • Rephrased transition text after import consolidation

autodiff.md (new lecture)

  • Adapted from lecture-jax/lectures/autodiff.md
  • Covers: finite differences vs symbolic calculus vs autodiff, differentiating through control flow and interpolation, gradient descent with Barzilai-Borwein, OLS regression example, polynomial regression exercise
  • Updated jax.random.PRNGKey → jax.random.key
  • Moved sympy import to top, dropped unused diff import
  • Removed nvidia-smi cell (GPU admonition covers this)
  • Added pip install cell and intro referencing the jax_intro lecture
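The finite-differences-vs-autodiff comparison the lecture covers can be sketched like this (the test function and step size are assumptions for illustration):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x**2

def fd_grad(f, x, h=1e-3):
    # Central finite difference: step size trades truncation
    # error against floating-point round-off
    return (f(x + h) - f(x - h)) / (2 * h)

df = jax.grad(f)   # autodiff: exact to float precision

x = 1.5
approx = fd_grad(f, x)
exact = df(x)
```

Autodiff sidesteps the step-size dilemma entirely, which is what makes it attractive for gradient descent and the OLS examples in the lecture.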

numpy_vs_numba_vs_jax.md

Bug fixes:

  • Fixed variable name z_max_numpy → z_max_numba for the Numba result
  • Fixed print label "JAX vmap v1" → "v2" in the vmap v2 section
  • Fixed garbled em dash (--— → ---)
  • Changed np.meshgrid → jnp.meshgrid when operating on JAX arrays (avoids silent host-to-device transfer)
  • Qualified "significantly faster due to GPU acceleration" → "significantly faster, especially on a GPU"
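The meshgrid fix can be illustrated as follows (array sizes and the surface function are assumptions): calling np.meshgrid on JAX arrays would first copy them back to host NumPy arrays, whereas jnp.meshgrid keeps the whole computation in JAX.

```python
import jax.numpy as jnp

x = jnp.linspace(-3, 3, 100)
y = jnp.linspace(-3, 3, 100)

# jnp.meshgrid keeps the arrays on-device; np.meshgrid would silently
# convert the JAX arrays to host NumPy arrays first
X, Y = jnp.meshgrid(x, y)
Z = jnp.exp(-(X**2 + Y**2))
```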

New content:

  • Added "Overall recommendations" synthesis section covering: JAX for vectorized work, Numba for sequential loops, lax.scan differentiability advantage, and a rule of thumb for choosing between them
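The lax.scan differentiability advantage mentioned above can be sketched with a toy state recursion (the recursion and coefficient are assumptions for illustration): because scan is a JAX primitive, gradients flow through the entire loop, which a Numba loop cannot offer.

```python
import jax
import jax.numpy as jnp
from jax import lax

def final_value(x0, shocks):
    # Roll the recursion x_{t+1} = 0.9 * x_t + shock_t with lax.scan
    def step(x, shock):
        x_next = 0.9 * x + shock
        return x_next, x_next
    x_final, _ = lax.scan(step, x0, shocks)
    return x_final

shocks = jnp.ones(5)

# The whole loop is differentiable: d(x_final)/d(x0) = 0.9**5
g = jax.grad(final_value)(1.0, shocks)
```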

Housekeeping:

  • Consolidated all imports to top (numba, lax, partial were mid-lecture)
  • Replaced cm.jet with cm.viridis (perceptually uniform, colorblind-friendly)

_toc.yml

  • Added autodiff after numpy_vs_numba_vs_jax in the "High Performance Computing" section

jstac (Contributor, Author) commented Apr 5, 2026

Please also review this and merge when ready @mmcky.

Please check that the figure showing the jax.jit compilation process in jax_intro came out okay.

Once these changes and those in #512 are merged, please make live.
