jax
flax>=0.7.1
transformer_engine_rocm7==2.14.0.dev0+rocm7.15.0a20260627.a4e9dd6
