diff --git a/cpp/distributed/dist-mnist.cpp b/cpp/distributed/dist-mnist.cpp index edbd816e59..579745683a 100644 --- a/cpp/distributed/dist-mnist.cpp +++ b/cpp/distributed/dist-mnist.cpp @@ -1,4 +1,7 @@ -#include +#define USE_C10D_MPI +#include +#include +#include #include #include @@ -35,8 +38,8 @@ struct Model : torch::nn::Module { }; void waitWork( - std::shared_ptr pg, - std::vector> works) { + c10::intrusive_ptr pg, + std::vector> works) { for (auto& work : works) { try { work->wait(); @@ -115,7 +118,7 @@ int main(int argc, char* argv[]) { // since this synchronizes parameters after backward pass while DDP // overlaps synchronizing parameters and computing gradients in backward // pass - std::vector> works; + std::vector<::c10::intrusive_ptr<::c10d::Work>> works; for (auto& param : model->named_parameters()) { std::vector tmp = {param.value().grad()}; auto work = pg->allreduce(tmp);