Commit 20918be

[core][rdt] Register your own transport at runtime for RDT
Signed-off-by: dayshah <dhyey2019@gmail.com>
1 parent a6d1bcf commit 20918be

File tree

doc/source/ray-core/api/direct-transport.rst
doc/source/ray-core/direct-transport.rst
python/ray/experimental/__init__.py
python/ray/experimental/gpu_object_manager/__init__.py
python/ray/experimental/gpu_object_manager/util.py

5 files changed: +66 -16 lines changed

doc/source/ray-core/api/direct-transport.rst

Lines changed: 2 additions & 1 deletion

@@ -34,4 +34,5 @@ Advanced APIs
    :nosignatures:
    :toctree: doc/
 
-   ray.experimental.wait_tensor_freed
+   ray.experimental.wait_tensor_freed
+   ray.experimental.register_tensor_transport

doc/source/ray-core/direct-transport.rst

Lines changed: 16 additions & 0 deletions

@@ -251,6 +251,22 @@ You can also use NIXL to retrieve the result from references created by :func:`r
    :start-after: __nixl_put__and_get_start__
    :end-before: __nixl_put__and_get_end__
 
+
+Registering a new tensor transport
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Ray allows users to register new tensor transports for use in RDT at runtime. To register a new tensor transport, use the :func:`ray.experimental.register_tensor_transport <ray.experimental.register_tensor_transport>` function.
+To implement a new tensor transport, implement the abstract interface defined in :class:`ray.experimental.gpu_object_manager.tensor_transport_manager.TensorTransportManager`.
+Then pass ``register_tensor_transport`` the transport name, the supported devices, and the class that implements ``TensorTransportManager``.
+NIXL, NCCL, and GLOO are registered through this API as well; see ``nixl_tensor_transport.py`` for a reference example.
+
+.. code-block:: python
+
+    from ray.experimental.gpu_object_manager import register_tensor_transport
+
+    register_tensor_transport("NIXL", ["cuda", "cpu"], NixlTensorTransport)
+
+
 Summary
 -------

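To make the shape of the new API concrete, here is a hedged sketch of registering a custom transport through this function. The RDMA transport name and class are hypothetical, and the abstract methods of TensorTransportManager that a real transport must override are elided; consult tensor_transport_manager.py for the actual interface.

# A sketch only: "MyRdmaTransport" and the "MYRDMA" name are hypothetical,
# and the abstract methods of TensorTransportManager are left unimplemented.
from ray.experimental import register_tensor_transport
from ray.experimental.gpu_object_manager.tensor_transport_manager import (
    TensorTransportManager,
)


class MyRdmaTransport(TensorTransportManager):
    # Override each abstract method of TensorTransportManager here before
    # using the transport; registration itself only stores the class.
    ...


# After this call, RDT can refer to the transport by name, like NIXL/NCCL/GLOO.
register_tensor_transport("MYRDMA", ["cuda", "cpu"], MyRdmaTransport)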
python/ray/experimental/__init__.py

Lines changed: 6 additions & 1 deletion

@@ -1,5 +1,9 @@
 from ray.experimental.dynamic_resources import set_resource
-from ray.experimental.gpu_object_manager import GPUObjectManager, wait_tensor_freed
+from ray.experimental.gpu_object_manager import (
+    GPUObjectManager,
+    register_tensor_transport,
+    wait_tensor_freed,
+)
 from ray.experimental.locations import get_local_object_locations, get_object_locations
 
 __all__ = [
@@ -8,4 +12,5 @@
     "set_resource",
     "GPUObjectManager",
     "wait_tensor_freed",
+    "register_tensor_transport",
 ]

python/ray/experimental/gpu_object_manager/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -2,5 +2,6 @@
     GPUObjectManager,
     wait_tensor_freed,
 )
+from ray.experimental.gpu_object_manager.util import register_tensor_transport
 
-__all__ = ["GPUObjectManager", "wait_tensor_freed"]
+__all__ = ["GPUObjectManager", "wait_tensor_freed", "register_tensor_transport"]

python/ray/experimental/gpu_object_manager/util.py

Lines changed: 40 additions & 13 deletions

@@ -1,5 +1,5 @@
 import threading
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Dict, List
 
 from ray._private.custom_types import TensorTransportEnum
 from ray.experimental.gpu_object_manager.collective_tensor_transport import (
@@ -11,31 +11,58 @@
 from ray.experimental.gpu_object_manager.tensor_transport_manager import (
     TensorTransportManager,
 )
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     import torch
 
 
 # Class definitions for transport managers
-transport_manager_classes: dict[str, TensorTransportManager] = {
-    "NIXL": NixlTensorTransport,
-    "GLOO": CollectiveTensorTransport,
-    "NCCL": CollectiveTensorTransport,
-}
-
-transport_devices = {
-    "NIXL": ["cuda", "cpu"],
-    "GLOO": ["cpu"],
-    "NCCL": ["cuda"],
-}
+transport_manager_classes: Dict[str, type[TensorTransportManager]] = {}
 
+transport_devices: Dict[str, List[str]] = {}
 
 # Singleton instances of transport managers
-transport_managers = {}
+transport_managers: Dict[str, TensorTransportManager] = {}
 
 transport_managers_lock = threading.Lock()
 
 
+@PublicAPI(stability="alpha")
+def register_tensor_transport(
+    transport_name: str,
+    devices: List[str],
+    transport_manager_class: type[TensorTransportManager],
+):
+    """
+    Register a new tensor transport for use in Ray.
+
+    Args:
+        transport_name: The name of the transport protocol.
+        devices: List of device types supported by this transport (e.g., ["cuda", "cpu"]).
+        transport_manager_class: A class that implements TensorTransportManager.
+
+    Raises:
+        ValueError: If transport_manager_class is not a class or does not subclass TensorTransportManager.
+    """
+    global transport_manager_classes
+    global transport_devices
+
+    if not issubclass(transport_manager_class, TensorTransportManager):
+        raise ValueError(
+            f"transport_manager_class {transport_manager_class.__name__} must be a subclass of TensorTransportManager."
+        )
+
+    transport_name = transport_name.upper()
+    transport_manager_classes[transport_name] = transport_manager_class
+    transport_devices[transport_name] = devices
+
+
+register_tensor_transport("NIXL", ["cuda", "cpu"], NixlTensorTransport)
+register_tensor_transport("GLOO", ["cpu"], CollectiveTensorTransport)
+register_tensor_transport("NCCL", ["cuda"], CollectiveTensorTransport)
+
+
 def get_tensor_transport_manager(
     transport_name: str,
 ) -> "TensorTransportManager":
