#include "heapsort.hpp"
#include "utils.hpp"

#include <algorithm>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

/// In-place heap sort over the random-access range [begin, end).
///
/// `comp(a, b)` must return true when `a` orders before `b`. The range is
/// first arranged into a max-heap with respect to `comp`. The root (largest
/// element) is then repeatedly swapped to the end of the shrinking unsorted
/// prefix. The result is ascending with respect to `comp`. The sort is not
/// stable.
template <typename T>
class HeapSorter {
public:
    /// Type produced by dereferencing the iterator (for standard containers,
    /// an lvalue reference to the element).
    using pointed = decltype(*std::declval<T>());
    /// Signed distance type between two iterators of type T.
    using difftype = decltype(std::declval<T>() - std::declval<T>());

    /// @param _begin     first element of the range to sort
    /// @param _end       one past the last element of the range
    /// @param _comparer  strict "less than" ordering on the elements
    HeapSorter(T _begin, T _end, const std::function<bool(pointed, pointed)> &_comparer)
        : begin(_begin), comp(_comparer), len(_end - _begin) {}

    /// Sorts the range in place.
    void doSort() {
        // Ranges of length 0 or 1 are already sorted. Returning early also
        // ensures no iterator past `end` is ever formed for an empty range.
        if (len < 2) {
            return;
        }
        buildMaxHeap();
        // Invariant: [0, i] is a max-heap; (i, len) holds the largest
        // elements in ascending order.
        for (difftype i = len - 1; i >= 1; --i) {
            std::iter_swap(begin, begin + i);
            maxHeapify(0, i);
        }
    }

private:
    const T begin;
    // Held by value (not by reference) so the sorter can never dangle even
    // if the caller's std::function is a temporary.
    const std::function<bool(pointed, pointed)> comp;
    const difftype len;

    /// Sifts the element at index `start` down until the subtree rooted there
    /// satisfies the max-heap property within the prefix [0, bottom).
    void maxHeapify(difftype start, difftype bottom) {
        difftype current = start;
        while (true) {
            const difftype left = 2 * current + 1;
            if (left >= bottom) {
                break; // `current` is a leaf within the heap prefix
            }
            const difftype right = left + 1;

            // Pick the largest of parent, left child, right child. Ties keep
            // the earlier candidate, matching a strict-weak `comp`.
            difftype largest = current;
            if (comp(*(begin + largest), *(begin + left))) {
                largest = left;
            }
            if (right < bottom && comp(*(begin + largest), *(begin + right))) {
                largest = right;
            }
            if (largest == current) {
                break; // heap property already holds here
            }
            std::iter_swap(begin + current, begin + largest);
            current = largest;
        }
    }

    /// Rearranges [0, len) into a max-heap. Requires len >= 2.
    void buildMaxHeap() {
        // Indices >= len / 2 are leaves, so heapify from the last internal
        // node (len / 2 - 1) back to the root.
        for (difftype i = len / 2 - 1; i >= 0; --i) {
            maxHeapify(i, len);
        }
    }
};

/// Heap-sorts [begin, end) in place using `comparer` as the "less than" order.
template <typename T>
void heapsort(T begin, T end,
              std::function<bool(decltype(*std::declval<T>()), decltype(*std::declval<T>()))> comparer) {
    HeapSorter<T> sorter(begin, end, comparer);
    sorter.doSort();
}

/// Heap-sorts [begin, end) in place into ascending order using operator<.
template <typename T>
void heapsort(T begin, T end) {
    heapsort(begin, end, std::less<>{});
}

int main() {
    // The first and last values lie outside the sorted range. They act as
    // sentinels for detecting out-of-bounds writes.
    std::vector<int> vec{12345, 1, 5, 3, 2, 5, 7, 9, -3, 12, 0, 9999};
    arrayprint(vec.begin() + 1, vec.end() - 1);
    std::cout << '\n';
    heapsort(vec.begin() + 1, vec.end() - 1);
    arrayprint(vec.begin() + 1, vec.end() - 1);
    std::cout << '\n';
    std::cout << '\n';

    // Sorting through reverse iterators leaves the range descending when it
    // is read forwards.
    std::cout << "sort by reverse iterators" << '\n';
    arrayprint(vec.begin() + 1, vec.end() - 1);
    std::cout << '\n';
    heapsort(vec.rbegin() + 1, vec.rend() - 1);
    arrayprint(vec.begin() + 1, vec.end() - 1);
    std::cout << '\n';
    std::cout << '\n';

    // Empty range: must be a no-op.
    std::vector<int> vec2;
    arrayprint(vec2.begin(), vec2.end());
    std::cout << '\n';
    heapsort(vec2.begin(), vec2.end());
    arrayprint(vec2.begin(), vec2.end());
    std::cout << std::endl;
    return 0;
}