|
1 | 1 | import threading |
2 | | -from typing import TYPE_CHECKING |
| 2 | +from typing import TYPE_CHECKING, Dict, List |
3 | 3 |
|
4 | 4 | from ray._private.custom_types import TensorTransportEnum |
5 | 5 | from ray.experimental.gpu_object_manager.collective_tensor_transport import ( |
|
11 | 11 | from ray.experimental.gpu_object_manager.tensor_transport_manager import ( |
12 | 12 | TensorTransportManager, |
13 | 13 | ) |
| 14 | +from ray.util.annotations import PublicAPI |
14 | 15 |
|
15 | 16 | if TYPE_CHECKING: |
16 | 17 | import torch |
17 | 18 |
|
18 | 19 |
|
19 | 20 | # Class definitions for transport managers |
20 | | -transport_manager_classes: dict[str, TensorTransportManager] = { |
21 | | - "NIXL": NixlTensorTransport, |
22 | | - "GLOO": CollectiveTensorTransport, |
23 | | - "NCCL": CollectiveTensorTransport, |
24 | | -} |
25 | | - |
26 | | -transport_devices = { |
27 | | - "NIXL": ["cuda", "cpu"], |
28 | | - "GLOO": ["cpu"], |
29 | | - "NCCL": ["cuda"], |
30 | | -} |
| 21 | +transport_manager_classes: Dict[str, type[TensorTransportManager]] = {} |
31 | 22 |
|
| 23 | +transport_devices: Dict[str, List[str]] = {} |
32 | 24 |
|
33 | 25 | # Singleton instances of transport managers |
34 | | -transport_managers = {} |
| 26 | +transport_managers: Dict[str, TensorTransportManager] = {} |
35 | 27 |
|
36 | 28 | transport_managers_lock = threading.Lock() |
37 | 29 |
|
38 | 30 |
|
@PublicAPI(stability="alpha")
def register_tensor_transport(
    transport_name: str,
    devices: List[str],
    transport_manager_class: type[TensorTransportManager],
):
    """
    Register a new tensor transport for use in Ray.

    The transport name is upper-cased before it is stored. Registering a
    name that is already present replaces the previous class and device
    list.

    Args:
        transport_name: The name of the transport protocol.
        devices: List of device types supported by this transport (e.g., ["cuda", "cpu"]).
        transport_manager_class: A class that implements TensorTransportManager.

    Raises:
        ValueError: If transport_manager_class is not a class or does not subclass TensorTransportManager.
    """
    # issubclass() raises TypeError when its first argument is not a class,
    # so check isinstance(..., type) first to honor the documented ValueError.
    is_valid_class = isinstance(transport_manager_class, type) and issubclass(
        transport_manager_class, TensorTransportManager
    )
    if not is_valid_class:
        # Non-class objects may not have __name__; fall back to repr().
        class_label = getattr(
            transport_manager_class, "__name__", repr(transport_manager_class)
        )
        raise ValueError(
            f"transport_manager_class {class_label} must be a subclass of TensorTransportManager."
        )

    # The registries are mutated in place (never rebound), so no `global`
    # declaration is required here.
    transport_name = transport_name.upper()
    transport_manager_classes[transport_name] = transport_manager_class
    transport_devices[transport_name] = devices
| 59 | + |
| 60 | + |
# Built-in transports, registered when this module is imported.
register_tensor_transport(
    transport_name="NIXL",
    devices=["cuda", "cpu"],
    transport_manager_class=NixlTensorTransport,
)
register_tensor_transport(
    transport_name="GLOO",
    devices=["cpu"],
    transport_manager_class=CollectiveTensorTransport,
)
register_tensor_transport(
    transport_name="NCCL",
    devices=["cuda"],
    transport_manager_class=CollectiveTensorTransport,
)
| 64 | + |
| 65 | + |
39 | 66 | def get_tensor_transport_manager( |
40 | 67 | transport_name: str, |
41 | 68 | ) -> "TensorTransportManager": |
|
0 commit comments