// File: DSA/sorting/heapsort.cpp
// Timestamp: 2024-10-04 16:11:55 -05:00
// Size at that revision: 146 lines, 4.5 KiB (C++)

#include "heapsort.hpp"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <iterator>
#include <utility>
#include <vector>

#include "utils.hpp"
/// Sorts a random-access range in place with heapsort.
///
/// The comparator is a "less-than" predicate: comp(a, b) returns true when
/// `a` should be ordered before `b`. The range is first rearranged into a
/// max-heap (with respect to `comp`), then the root (the maximum) is
/// repeatedly swapped behind the shrinking heap. The result is ascending
/// with respect to `comp`.
template <std::random_access_iterator T>
class HeapSorter
{
public:
    /// Type produced by dereferencing the iterator (e.g. `int&`).
    using pointed = decltype(*std::declval<T>());
    /// Distance type between two iterators of type T.
    using difftype = decltype(std::declval<T>() - std::declval<T>());

    /// @param _begin    first element of the range to sort
    /// @param _end      one past the last element of the range
    /// @param _comparer ordering predicate; it is copied, so passing a
    ///                  temporary (e.g. a lambda) is safe
    HeapSorter(T _begin, T _end, const std::function<bool(pointed, pointed)> &_comparer)
        : begin(_begin), comp(_comparer), len(_end - _begin)
    {
    }

    /// Sorts [begin, end) in place. Ranges of length 0 or 1 are left untouched.
    void doSort()
    {
        buildMaxHeap();
        // Move the current maximum (the root) to position i, just past the
        // heap prefix, then restore the heap property on [0, i).
        for (difftype i = len - 1; i >= 1; --i)
        {
            std::iter_swap(begin, begin + i);
            maxHeapify(0, i);
        }
    }

private:
    const T begin;
    // Held by value: a reference member would dangle whenever the
    // constructor was given a temporary std::function.
    const std::function<bool(pointed, pointed)> comp;
    const difftype len;

    /// Sifts the element at index `start` down within the heap prefix
    /// [0, bottom) until no child compares greater than it under `comp`.
    void maxHeapify(difftype start, difftype bottom)
    {
        difftype current = start;
        while (true)
        {
            const difftype left = 2 * current + 1;
            if (left >= bottom)
                break; // `current` is a leaf of the heap prefix
            const difftype right = left + 1;

            // Pick the greatest of current, left child, and right child.
            difftype largest = current;
            if (comp(*(begin + largest), *(begin + left)))
                largest = left;
            if (right < bottom && comp(*(begin + largest), *(begin + right)))
                largest = right;

            if (largest == current)
                break; // heap property already holds at this node
            std::iter_swap(begin + current, begin + largest);
            current = largest;
        }
    }

    /// Rearranges the whole range into a max-heap, bottom-up.
    void buildMaxHeap()
    {
        // Indices >= len / 2 are leaves, so start at the last internal node.
        // This also guarantees no iterator beyond `end` is ever formed,
        // including for an empty range.
        for (difftype i = len / 2 - 1; i >= 0; --i)
            maxHeapify(i, len);
    }
};
/// Sorts [begin, end) in place with heapsort, ordered by `comparer`.
///
/// `comparer(a, b)` should return true when `a` belongs before `b`; the
/// range ends up ascending with respect to it. `comparer` is taken by value,
/// so it stays alive for the whole sort performed by the temporary sorter.
template <std::random_access_iterator T>
void heapsort(T begin, T end, std::function<bool(decltype(*std::declval<T>()), decltype(*std::declval<T>()))> comparer)
{
    HeapSorter<T>{begin, end, comparer}.doSort();
}
/// Sorts [begin, end) in place into ascending order using `operator<`.
template <std::random_access_iterator T>
void heapsort(T begin, T end)
{
    using reference = decltype(*std::declval<T>());
    const auto ascending = [](reference lhs, reference rhs) { return lhs < rhs; };
    heapsort(begin, end, ascending);
}
/// Demo driver: sorts an interior sub-range forwards, then through reverse
/// iterators, then sorts an empty vector. It prints each range before and
/// after sorting, and asserts that the result is ordered and the sentinels
/// survived.
int main()
{
    // The first and last elements are sentinels: they lie outside every range
    // handed to heapsort, so a change to either means an out-of-bounds write.
    std::vector<std::int32_t> vec{12345, 1, 5, 3, 2, 5, 7, 9, -3, 12, 0, 9999};
    const auto checkSentinels = [&vec] {
        assert(vec.front() == 12345);
        assert(vec.back() == 9999);
    };

    arrayprint(vec.begin() + 1, vec.end() - 1);
    std::cout << '\n';
    heapsort(vec.begin() + 1, vec.end() - 1);
    arrayprint(vec.begin() + 1, vec.end() - 1);
    std::cout << "\n\n";
    checkSentinels();
    assert(std::is_sorted(vec.begin() + 1, vec.end() - 1));

    // Sorting ascending through reverse iterators leaves the underlying
    // storage in descending order.
    std::cout << "sort by reverse iterators" << '\n';
    arrayprint(vec.begin() + 1, vec.end() - 1);
    std::cout << '\n';
    heapsort(vec.rbegin() + 1, vec.rend() - 1);
    arrayprint(vec.begin() + 1, vec.end() - 1);
    std::cout << "\n\n";
    checkSentinels();
    assert(std::is_sorted(vec.rbegin() + 1, vec.rend() - 1));

    // An empty range must be a no-op.
    std::vector<std::int32_t> vec2;
    arrayprint(vec2.begin(), vec2.end());
    std::cout << '\n';
    heapsort(vec2.begin(), vec2.end());
    arrayprint(vec2.begin(), vec2.end());
    std::cout << '\n';
    assert(vec2.empty());
    return 0;
}