From 80702581b61c68de90273fc911d57ac36bc87b8e Mon Sep 17 00:00:00 2001 From: qianduoduo0904 Date: Thu, 8 Dec 2022 08:31:32 +0000 Subject: [PATCH] Fix block size for transfer --- mars/oscar/__init__.py | 1 + mars/oscar/backends/transfer.py | 15 +++++++++++++++ mars/services/storage/transfer.py | 9 +++++---- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/mars/oscar/__init__.py b/mars/oscar/__init__.py index 829ec1f4e2..0541e419bd 100644 --- a/mars/oscar/__init__.py +++ b/mars/oscar/__init__.py @@ -39,6 +39,7 @@ ) from .backends import allocate_strategy from .backends.pool import MainActorPoolType +from .backends.transfer import temp_transfer_block_size from .batch import extensible from .core import ( ActorRef, diff --git a/mars/oscar/backends/transfer.py b/mars/oscar/backends/transfer.py index f4f12f9d42..2559abf33f 100644 --- a/mars/oscar/backends/transfer.py +++ b/mars/oscar/backends/transfer.py @@ -40,6 +40,21 @@ DEFAULT_TRANSFER_BLOCK_SIZE = 4 * 1024**2 +@contextlib.contextmanager +def temp_transfer_block_size(size: int): + global DEFAULT_TRANSFER_BLOCK_SIZE + + if size == DEFAULT_TRANSFER_BLOCK_SIZE: + yield + else: + default_size = DEFAULT_TRANSFER_BLOCK_SIZE + DEFAULT_TRANSFER_BLOCK_SIZE = size + try: + yield + finally: + DEFAULT_TRANSFER_BLOCK_SIZE = default_size + + def _get_buffer_size(buf) -> int: try: return buf.nbytes diff --git a/mars/services/storage/transfer.py b/mars/services/storage/transfer.py index 8c903655e4..9ae20e484e 100644 --- a/mars/services/storage/transfer.py +++ b/mars/services/storage/transfer.py @@ -207,10 +207,11 @@ async def send_batch_data( rest_keys.append(data_key) if local_buffers: - # for data that supports buffer protocol on both sides - # hand over to oscar to transfer data - await mo.copyto_via_buffers(local_buffers, remote_buffer_refs) - await receiver_ref.close_writers(session_id, copied_keys) + with mo.temp_transfer_block_size(block_size): + # for data that supports buffer protocol on both sides + # hand over to oscar to transfer data + await mo.copyto_via_buffers(local_buffers, remote_buffer_refs) + await receiver_ref.close_writers(session_id, copied_keys) else: rest_keys = to_send_keys rest_readers = readers