Fix JAX extension build with NVTE_UB_WITH_MPI=1 #2835
Greptile Summary
This PR fixes a runtime undefined symbol error when building the JAX extension with
Confidence Score: 5/5
Safe to merge: targeted bug fix with correct implementation and no regressions introduced.
The fix is minimal and correct: it adds the missing
No files require special attention.
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Build with NVTE_UB_WITH_MPI=1] --> B{Extension type}
B -->|JAX| C[jax.py: setup_jax_extension]
B -->|PyTorch| D[pytorch.py: setup_pytorch_extension]
C --> E[setup_mpi_flags - utils.py]
D --> E
E --> F{NVTE_UB_WITH_MPI set?}
F -->|No| G[No MPI flags added]
F -->|Yes| H{MPI_HOME set?}
H -->|No| I[Assert error: MPI_HOME required]
H -->|Yes| J[Add MPI include dir]
J --> K[Add -DNVTE_UB_WITH_MPI flag]
K --> L[Extension compiled with matching symbol mangling]
L --> M[libtransformer_engine.so symbols resolved at runtime]
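The decision path in the flowchart can be sketched in Python roughly as follows. This is a minimal illustration, not the actual `setup_mpi_flags` in `utils.py`: the function name `sketch_mpi_flags`, its signature, the `int(...)` parsing of `NVTE_UB_WITH_MPI`, and the `include` subdirectory under `MPI_HOME` are all assumptions made for the example, and it raises an exception where the flowchart shows an assert.

```python
from pathlib import Path
from typing import List, Mapping, Tuple


def sketch_mpi_flags(env: Mapping[str, str]) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: return (include_dirs, cxx_flags) per the flowchart."""
    include_dirs: List[str] = []
    cxx_flags: List[str] = []

    # Assumption: the option is "set" when it parses to a non-zero integer.
    if not int(env.get("NVTE_UB_WITH_MPI", "0") or "0"):
        return include_dirs, cxx_flags  # No MPI flags added

    # Reject both an unset and an empty MPI_HOME.
    mpi_home = env.get("MPI_HOME")
    if not mpi_home:
        raise RuntimeError(
            "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
        )

    # Assumption: MPI headers live under $MPI_HOME/include.
    include_dirs.append(str(Path(mpi_home) / "include"))
    # The same definition must reach every extension that links against
    # libtransformer_engine.so, so the mangled constructor names agree.
    cxx_flags.append("-DNVTE_UB_WITH_MPI")
    return include_dirs, cxx_flags
```

Calling the same helper from both the JAX and PyTorch build paths, for example `sketch_mpi_flags(os.environ)`, is what keeps the two extensions from drifting apart on this flag.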
Reviews (4): Last reviewed commit: "Merge branch 'main' into main"
build_tools/jax.py
Outdated
    assert (
        os.getenv("MPI_HOME") is not None
Empty MPI_HOME string bypasses the guard
os.getenv("MPI_HOME") returns None only when the variable is unset. If a user exports MPI_HOME="" (empty string), the assert passes (empty string is not None), and Path("") silently resolves to the current working directory — not a valid MPI installation — causing confusing compile errors downstream.
Consider checking for a non-empty value:
    - assert (
    -     os.getenv("MPI_HOME") is not None
    + mpi_home = os.getenv("MPI_HOME")
    + assert mpi_home, (
    +     "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
    + )
    + mpi_path = Path(mpi_home)
This also avoids calling os.getenv("MPI_HOME") twice (once in the assert, once for Path(...)). Note: the same pattern exists in build_tools/pytorch.py, lines 71–74.
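To make the failure mode concrete, here is a small sketch showing why an `is not None` check lets an empty value through and what a truthiness check changes. The helper name `require_mpi_home` is hypothetical and is not part of the build tools; it raises `ValueError` rather than using `assert`.

```python
from pathlib import Path
from typing import Optional


def require_mpi_home(value: Optional[str]) -> Path:
    """Hypothetical helper: reject both an unset (None) and an empty MPI_HOME."""
    if not value:
        raise ValueError(
            "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
        )
    return Path(value)


# What `export MPI_HOME=""` looks like from Python's side:
empty = ""
print(empty is not None)         # True: the original guard passes
print(Path(empty) == Path("."))  # True: the path is the current directory
```

One design consideration, offered as a suggestion rather than a requirement of this PR: `assert` statements are skipped when Python runs with `-O`, so an explicit exception keeps the check active in that mode as well.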
Signed-off-by: Gaetan Lepage <gaetan@glepage.com>
for more information, see https://pre-commit.ci
/te-ci L1
Description
When building Transformer Engine with NVTE_UB_WITH_MPI=1 and NVTE_FRAMEWORK=pytorch,jax, the JAX extension (transformer_engine_jax) fails to load at runtime with an undefined symbol error, while the PyTorch extension works fine.

In userbuffers.h, the ExtComm type is conditionally defined based on NVTE_UB_WITH_MPI: it is MPI_Comm when the macro is defined and const char * otherwise.

This type flows into ExtAllgatherOp and ExtBarrierOp, which are parameters of the CommOverlapP2PBase constructor. Because C++ encodes parameter types in mangled names, the constructor has a different mangled symbol name depending on whether NVTE_UB_WITH_MPI is defined.

The core library (libtransformer_engine.so) is built via CMake, which correctly sets -DNVTE_UB_WITH_MPI. The PyTorch extension also adds this flag. However, the JAX extension is missing this flag entirely.

As a result, transformer_engine_jax.so is compiled expecting the const char * variant of the constructor, while libtransformer_engine.so only exports the MPI_Comm variant, causing an undefined symbol error at import time.

Type of change
Changes
This PR adds the MPI include path and the -DNVTE_UB_WITH_MPI compile definition to the JAX extension build, mirroring the existing handling in build_tools/pytorch.py.

Checklist: