Strip local version labels from package version checks#2858
Strip local version labels from package version checks#2858pstjohn wants to merge 2 commits intoNVIDIA:mainfrom
Conversation
Pre-compiled Flash Attention wheels (e.g. from mjun0812/flash-attention-prebuild-wheels) embed build metadata in their package version string (e.g. "2.8.3+cu130torch2.11"). While flash_attn.__version__ returns the clean "2.8.3", TE reads the version via importlib.metadata which returns the full string including the local segment. Under PEP 440, "2.8.3+local" > "2.8.3", causing version range checks like `min_version <= version <= max_version` to incorrectly reject a compatible installation. Use `Version.public` to strip the local label before comparison at all `get_pkg_version` call sites (flash-attn, flash-attn-3, jax). Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Greptile SummaryThis PR fixes a version comparison bug where pre-compiled wheels embedding local build metadata (e.g. Confidence Score: 5/5Safe to merge — the fix is correct, minimal, and covers all affected call sites. All three version-comparison call sites are correctly patched using PEP 440's .public property. No P0/P1 issues found. The only feedback is a P2 style suggestion to centralise the double-wrapping pattern into a helper function. No files require special attention.
|
| Filename | Overview |
|---|---|
| transformer_engine/pytorch/attention/dot_product_attention/backends.py | Strips local version labels from flash-attn v2 and v3 metadata before comparison by double-wrapping with PkgVersion().public; fix is correct and complete for both call sites. |
| transformer_engine/jax/version_utils.py | Applies the same local-label-stripping fix to the JAX version check; semantically correct. |
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["importlib.metadata.version(pkg)"] --> B["Returns full string\ne.g. '2.8.3+cu130torch2.11'"]
B --> C["PkgVersion(full_string)"]
C --> D[".public property\nstrips local label"]
D --> E["'2.8.3'"]
E --> F["PkgVersion('2.8.3')"]
F --> G{"min_version <= version\n<= max_version"}
G -->|"Pass ✓"| H["Install accepted"]
G -->|"Fail ✗"| I["Install rejected\n(was wrong without fix)"]
Reviews (2): Last reviewed commit: "Merge branch 'main' into pstjohn/strip-p..." | Re-trigger Greptile
|
/te-ci |
Pre-compiled Flash Attention wheels (e.g. from
mjun0812/flash-attention-prebuild-wheels) embed build metadata in their package version string (e.g. "2.8.3+cu130torch2.11"). While flash_attn.version returns the clean "2.8.3", TE reads the version via importlib.metadata which returns the full string including the local segment. Under PEP 440, "2.8.3+local" > "2.8.3", causing version range checks like
min_version <= version <= max_versionto incorrectly reject a compatible installation.Use
Version.publicto strip the local label before comparison at allget_pkg_versioncall sites (flash-attn, flash-attn-3, jax)