reduce_2.cpp 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  /*
   * Copyright 2018-present Facebook, Inc.
   *
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   * See the License for the specific language governing permissions and
   * limitations under the License.
   */
  #include <algorithm>
  #include <atomic>
  #include <cassert>
  #include <cstddef>
  #include <exception>
  #include <functional>
  #include <iostream>
  #include <memory>
  #include <numeric>
  #include <thread>
  #include <tuple>
  #include <type_traits>
  #include <vector>

  #include <folly/experimental/pushmi/examples/pool.h>
  #include <folly/experimental/pushmi/examples/reduce.h>

  using namespace pushmi::aliases;
  25. template <class Executor, class Allocator = std::allocator<char>>
  26. auto naive_executor_bulk_target(Executor e, Allocator a = Allocator{}) {
  27. return [e, a](
  28. auto init,
  29. auto selector,
  30. auto input,
  31. auto&& func,
  32. auto sb,
  33. auto se,
  34. auto out) {
  35. using RS = decltype(selector);
  36. using F = std::conditional_t<
  37. std::is_lvalue_reference<decltype(func)>::value,
  38. decltype(func),
  39. typename std::remove_reference<decltype(func)>::type>;
  40. using Out = decltype(out);
  41. try {
  42. typename std::allocator_traits<Allocator>::template rebind_alloc<char>
  43. allocState(a);
  44. auto shared_state = std::allocate_shared<std::tuple<
  45. std::exception_ptr, // first exception
  46. Out, // destination
  47. RS, // selector
  48. F, // func
  49. std::atomic<decltype(init(input))>, // accumulation
  50. std::atomic<std::size_t>, // pending
  51. std::atomic<std::size_t> // exception count (protects assignment to
  52. // first exception)
  53. >>(
  54. allocState,
  55. std::exception_ptr{},
  56. std::move(out),
  57. std::move(selector),
  58. (decltype(func)&&)func,
  59. init(std::move(input)),
  60. 1,
  61. 0);
  62. e | op::submit([e, sb, se, shared_state](auto) {
  63. auto stepDone = [](auto shared_state) {
  64. // pending
  65. if (--std::get<5>(*shared_state) == 0) {
  66. // first exception
  67. if (std::get<0>(*shared_state)) {
  68. mi::set_error(
  69. std::get<1>(*shared_state), std::get<0>(*shared_state));
  70. return;
  71. }
  72. try {
  73. // selector(accumulation)
  74. auto result = std::get<2>(*shared_state)(
  75. std::move(std::get<4>(*shared_state).load()));
  76. mi::set_value(std::get<1>(*shared_state), std::move(result));
  77. } catch (...) {
  78. mi::set_error(
  79. std::get<1>(*shared_state), std::current_exception());
  80. }
  81. }
  82. };
  83. for (decltype(sb) idx{sb}; idx != se;
  84. ++idx, ++std::get<5>(*shared_state)) {
  85. e | op::submit([shared_state, idx, stepDone](auto ex) {
  86. try {
  87. // this indicates to me that bulk is not the right abstraction
  88. auto old = std::get<4>(*shared_state).load();
  89. auto step = old;
  90. do {
  91. step = old;
  92. // func(accumulation, idx)
  93. std::get<3> (*shared_state)(step, idx);
  94. } while (!std::get<4>(*shared_state)
  95. .compare_exchange_strong(old, step));
  96. } catch (...) {
  97. // exception count
  98. if (std::get<6>(*shared_state)++ == 0) {
  99. // store first exception
  100. std::get<0>(*shared_state) = std::current_exception();
  101. } // else eat the exception
  102. }
  103. stepDone(shared_state);
  104. });
  105. }
  106. stepDone(shared_state);
  107. });
  108. } catch (...) {
  109. e |
  110. op::submit([out = std::move(out), ep = std::current_exception()](
  111. auto) mutable { mi::set_error(out, ep); });
  112. }
  113. };
  114. }
  115. int main() {
  116. mi::pool p{std::max(1u, std::thread::hardware_concurrency())};
  117. std::vector<int> vec(10);
  118. std::fill(vec.begin(), vec.end(), 4);
  119. auto fortyTwo = mi::reduce(
  120. naive_executor_bulk_target(p.executor()),
  121. vec.begin(),
  122. vec.end(),
  123. 2,
  124. std::plus<>{});
  125. assert(std::accumulate(vec.begin(), vec.end(), 2) == fortyTwo);
  126. std::cout << "OK" << std::endl;
  127. p.wait();
  128. }