File tree 1 file changed +7
-4
lines changed
1 file changed +7
-4
lines changed Original file line number Diff line number Diff line change 1
- #include < c10d/ProcessGroupMPI.hpp>
1
+ #define USE_C10D_MPI
2
+ #include < torch/csrc/distributed/c10d/Work.hpp>
3
+ #include < torch/csrc/distributed/c10d/ProcessGroup.hpp>
4
+ #include < torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>
2
5
#include < torch/torch.h>
3
6
#include < iostream>
4
7
@@ -35,8 +38,8 @@ struct Model : torch::nn::Module {
35
38
};
36
39
37
40
void waitWork (
38
- std::shared_ptr <c10d::ProcessGroupMPI> pg,
39
- std::vector<std::shared_ptr <c10d::ProcessGroup ::Work>> works) {
41
+ c10::intrusive_ptr <c10d::ProcessGroupMPI> pg,
42
+ std::vector<c10::intrusive_ptr <c10d::Work>> works) {
40
43
for (auto & work : works) {
41
44
try {
42
45
work->wait ();
@@ -115,7 +118,7 @@ int main(int argc, char* argv[]) {
115
118
// since this synchronizes parameters after backward pass while DDP
116
119
// overlaps synchronizing parameters and computing gradients in backward
117
120
// pass
118
- std::vector<std::shared_ptr<::c10d::ProcessGroup ::Work>> works;
121
+ std::vector<::c10::intrusive_ptr<::c10d ::Work>> works;
119
122
for (auto & param : model->named_parameters ()) {
120
123
std::vector<torch::Tensor> tmp = {param.value ().grad ()};
121
124
auto work = pg->allreduce (tmp);
You can’t perform that action at this time.
0 commit comments