Skip to content

Commit 9a05005

Browse files
committed
fix: allocate custom_ar comm buffer with hipMalloc on ROCm (was torch::empty)
1 parent 398bbaf commit 9a05005

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

rtp_llm/cpp/devices/rocm_impl/ROCmDistributedOp.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,11 @@ AllReduceOutput ROCmDevice::allReduce(const AllReduceParams& params) {
111111
if (use_custom_ar) {
112112
auto custom_ar_res_buf =
113113
allocateBuffer({buffer->type(), buffer->shape(), AllocationType::DEVICE}, {"custom_ar_buf"});
114+
printBufferData(*buffer, "ar_input_buffer");
114115
torch::Tensor input_tensor = Buffer2torchTensor(*buffer, false);
115116
torch::Tensor output_tensor = Buffer2torchTensor(*custom_ar_res_buf, false);
116117
custom_allreduce_comm_->allReduce(input_tensor, output_tensor);
118+
printBufferData(*custom_ar_res_buf, "ar_output_buffer_after");
117119
return AllReduceOutput{custom_ar_res_buf};
118120
}
119121

rtp_llm/cpp/rocm/custom_ar/custom_ar_comm.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ CustomAllReduceComm::~CustomAllReduceComm() {
2929

3030
bool CustomAllReduceComm::checkAllReduceAvailable(size_t elts_total_num, DataType data_type, size_t world_size) {
3131
size_t elts_total_size = elts_total_num * getTypeSize(data_type);
32-
3332
if (elts_total_size % 16 != 0) {
3433
return false;
3534
}
@@ -86,11 +85,10 @@ void CustomAllReduceComm::init(const NcclParam& nccl_para, hipStream_t stream) {
8685

8786
// meta data buffers need to be "uncached" for signal on MI200
8887
meta_ = aiter::allocate_meta_buffer(aiter::meta_size() + comm_buf_threshold_);
89-
buffer_ = torch::empty(
90-
{
91-
comm_buf_threshold_,
92-
},
93-
torch::dtype(torch::kUInt8).device(torch::kCUDA));
88+
void* raw_ptr;
89+
hipMalloc(&raw_ptr, comm_buf_threshold_);
90+
auto deleter = [](void* p) { hipFree(p); };
91+
buffer_ = torch::from_blob(raw_ptr, {comm_buf_threshold_}, deleter, torch::kCUDA);
9492
rank_data_ = torch::empty({16 * 1024 * 1024}, torch::dtype(torch::kUInt8).device(torch::kCUDA));
9593

9694
std::vector<torch::Tensor> meta_handles = prepareP2PBuffer_(nccl_para, meta_, stream);

0 commit comments

Comments (0)