N要素の分岐しない odd-even mergesort (GCC+Linux/x86_64用)
(こちらの記事の続き)
(ジェネレータを生成)
% g++ -O2 -Wall gen_sort.cpp -o gen_sort

(8要素向けのソート関数を作成)
% ./gen_sort 0 8 > oem_sort_8.s
% gcc -c oem_sort_8.s

(ちゃんと作られたようだ)
% nm oem_sort_8.o
0000000000000000 T odd_even_sort8

(テストコード書き)
% cat > test.c
#include <assert.h>

void odd_even_sort8(unsigned int* d);

int main()
{
    int i;
    unsigned int s[] = {7,6,5,4,1,2,3,4};
    odd_even_sort8(s);
    for(i = 0; i < 7; ++i) {
        assert(s[i] <= s[i + 1]);
    }
    return 0;
}
^D

(テストコードをコンパイル、ソート関数とリンク)
% gcc -Wall test.c oem_sort_8.o

(無事実行)
% ./a.out
%

(ジェネレータの使いかた)
% ./gen_sort
usage: ./gen_sort [0|1] N
0: odd-even sorting, 1: bitonic sorting
N: 4, 8, 16, 32, 64, ...
以下ソースコード。
// gen_sort.cpp
//
// Generates a branch-free sorting network for N unsigned ints as x86_64
// assembly (AT&T syntax, GNU as, GCC + Linux / System V ABI).
// Writing something like this in C++ is admittedly a bit mad...
//
// Build:  g++ -O2 -Wall gen_sort.cpp -o gen_sort
// Usage:  ./gen_sort [0|1] N > out.s
//           0: Batcher's odd-even mergesort network
//           1: bitonic sorting network
//           N: element count, a power of two >= 4
//
// The emitted routine has the C prototype
//     void odd_even_sortN(unsigned int* d);   // or bitonic_sortN
// and sorts d[0] .. d[N-1] in place into ascending (unsigned) order using
// only mov/cmp/cmov, i.e. without any conditional branch.
//
// Only the C++ standard library is required.

#include <cassert>
#include <cerrno>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>

// Returns true if n is a positive power of two (1, 2, 4, 8, ...).
static bool is_power_of_two(std::size_t n) {
    return n != 0 && (n & (n - 1)) == 0;
}

// Parses s as a complete decimal integer. Returns false (leaving *value
// untouched) on empty input, trailing junk, or overflow.
static bool parse_long(const char* s, long* value) {
    char* end = 0;
    errno = 0;
    const long v = std::strtol(s, &end, 10);
    if (end == s || *end != '\0' || errno == ERANGE) {
        return false;
    }
    *value = v;
    return true;
}

// One general-purpose register as used by the generated code.
struct Reg {
    const char* const reg;       // 32-bit name; holds one unsigned int element
    const char* const reg_clob;  // 64-bit name; used to save/restore it with push/pop
    const bool callee_save;      // must be preserved for the caller (System V ABI)
};

// Writes an assembly routine for a sorting network to an output stream.
//
// The first NUM_REGS elements are loaded into REGS for the whole routine; any
// further elements stay in memory and are compare-exchanged through TMP_REGS.
class x86_64 {
public:
    explicit x86_64(std::ostream& ost) : ost_(ost) {}

    // Emits the complete routine for an n-element network and returns the
    // number of comparators it contains.
    //
    // Algo must provide
    //   std::string operator()(std::size_t n)    -- symbol name of the routine
    //   void operator()(std::size_t n, Out out)  -- the network, reported by
    //                                               calling out(i, j, direction)
    //                                               once per comparator.
    template <typename Algo>
    std::size_t out_asm(std::size_t n, Algo algo) {
        const std::string name = algo(n);
        std::size_t count = 0;
        out_prologue(name, n);
        algo(n, Emitter(this, &count));
        out_epilogue(name, n);
        return count;
    }

private:
    enum { NUM_REGS = 9, NUM_TMP_REGS = 3 };

    static const char* const SRC_REG;
    static const Reg REGS[NUM_REGS];
    static const Reg TMP_REGS[NUM_TMP_REGS];

    std::ostream& ost_;

    // Non-copyable (declared, never defined): refers to a stream it does not own.
    x86_64(const x86_64&);
    x86_64& operator=(const x86_64&);

    // Comparator callback handed to the network generator: emits the
    // compare-exchange and counts it.
    class Emitter {
    public:
        Emitter(x86_64* owner, std::size_t* count) : owner_(owner), count_(count) {}

        void operator()(std::size_t a, std::size_t b, int direction) const {
            owner_->out_cmp_swap(a, b, direction);
            ++*count_;
        }

    private:
        x86_64* owner_;
        std::size_t* count_;
    };

    // Number of elements that get a dedicated register.
    static std::size_t regs_in_use(std::size_t n) {
        if (n < NUM_REGS) {
            return n;
        }
        return NUM_REGS;
    }

    // Number of temporaries needed: the memory path of out_cmp_swap (taken
    // only when some element has no register) uses all of them.
    static std::size_t tmps_in_use(std::size_t n) {
        if (n > NUM_REGS) {
            return NUM_TMP_REGS;
        }
        return 1;
    }

    // Operand text for element i: its register if it has one, else its slot in d.
    std::string location(std::size_t i) const {
        if (i < NUM_REGS) {
            return REGS[i].reg;
        }
        return get_mem_location(i);
    }

    // Emits one compare-exchange of elements a and b (callers pass a < b).
    // direction == 1 leaves the smaller value in a and the larger in b; any
    // other direction does the opposite. cmova/cmovb compare unsigned, which
    // matches the unsigned int element type.
    void out_cmp_swap(std::size_t a, std::size_t b, int direction) {
        const char* const cc = direction == 1 ? "a " : "b ";
        const std::string x = location(a);
        const std::string y = location(b);
        const char* const t0 = TMP_REGS[0].reg;

        if (a < NUM_REGS && b < NUM_REGS) {
            // Both in registers: t0 keeps the old y so x and y can be exchanged.
            ost_ << "\tmov " << y << ", " << t0 << '\n'
                 << "\tcmp " << y << ", " << x << '\n'
                 << "\tcmov" << cc << x << ", " << y << '\n'
                 << "\tcmov" << cc << t0 << ", " << x << '\n';
        } else {
            // At least one operand is in memory and cmov cannot write memory:
            // load both into temporaries, conditionally swap, store both back.
            const char* const t1 = TMP_REGS[1].reg;
            const char* const t2 = TMP_REGS[2].reg;
            ost_ << "\tmov " << y << ", " << t0 << '\n'
                 << "\tmov " << x << ", " << t1 << '\n'
                 << "\tcmp " << t0 << ", " << t1 << '\n'
                 << "\tcmov" << cc << t1 << ", " << t2 << '\n'
                 << "\tcmov" << cc << t0 << ", " << t1 << '\n'
                 << "\tcmov" << cc << t2 << ", " << t0 << '\n'
                 << "\tmov " << t0 << ", " << y << '\n'
                 << "\tmov " << t1 << ", " << x << '\n';
        }
    }

    // Saves every callee-saved register the routine will overwrite.
    void push_clobber_regs(std::size_t n) {
        for (std::size_t i = 0; i < regs_in_use(n); ++i) {
            if (REGS[i].callee_save) {
                ost_ << "\tpushq " << REGS[i].reg_clob << '\n';
            }
        }
        for (std::size_t i = 0; i < tmps_in_use(n); ++i) {
            ost_ << "\tpushq " << TMP_REGS[i].reg_clob << '\n';
        }
    }

    // Restores the registers saved by push_clobber_regs, in reverse order.
    void pop_clobber_regs(std::size_t n) {
        for (std::size_t i = tmps_in_use(n); i > 0; --i) {
            ost_ << "\tpopq " << TMP_REGS[i - 1].reg_clob << '\n';
        }
        for (std::size_t i = regs_in_use(n); i > 0; --i) {
            if (REGS[i - 1].callee_save) {
                ost_ << "\tpopq " << REGS[i - 1].reg_clob << '\n';
            }
        }
    }

    // Loads the register-resident elements from d.
    void load_regs(std::size_t n) {
        for (std::size_t i = 0; i < regs_in_use(n); ++i) {
            ost_ << "\tmovl " << get_mem_location(i) << ", " << REGS[i].reg << '\n';
        }
    }

    // Stores the register-resident elements back into d.
    void store_regs(std::size_t n) {
        for (std::size_t i = 0; i < regs_in_use(n); ++i) {
            ost_ << "\tmovl " << REGS[i].reg << ", " << get_mem_location(i) << '\n';
        }
    }

    void out_prologue(const std::string& name, std::size_t n) {
        ost_ << ".text\n"
             << ".globl " << name << '\n'
             << ".type " << name << ", @function\n"
             << name << ":\n";
        push_clobber_regs(n);
        load_regs(n);
    }

    void out_epilogue(const std::string& name, std::size_t n) {
        store_regs(n);
        pop_clobber_regs(n);
        // No stack frame was set up, so no leaveq is needed before returning.
        ost_ << "\tretq\n"
             << ".size " << name << ", .-" << name << '\n'
             // Declare a non-executable stack; without this note GNU ld
             // assumes the object needs an executable stack.
             << ".section .note.GNU-stack,\"\",@progbits\n";
    }

    // AT&T memory operand for d[i], e.g. "(%rdi)" or "12(%rdi)".
    std::string get_mem_location(std::size_t i) const {
        std::ostringstream ret;
        if (i > 0) {
            ret << 4 * i;  // 4 == sizeof(unsigned int) on x86_64
        }
        ret << "(" << SRC_REG << ")";
        return ret.str();
    }
};

// The array pointer d arrives in the first integer-argument register.
const char* const x86_64::SRC_REG = "%rdi";

// Registers that hold d[0] .. d[8]; %rdi is not among them because it holds d.
const Reg x86_64::REGS[x86_64::NUM_REGS] = {
    { "%eax",  "%rax", false },
    { "%esi",  "%rsi", false },
    { "%edx",  "%rdx", false },
    { "%ecx",  "%rcx", false },
    { "%r8d",  "%r8",  false },
    { "%r9d",  "%r9",  false },
    { "%r10d", "%r10", false },
    { "%r11d", "%r11", false },
    { "%r12d", "%r12", true  },
};

// Scratch registers for the compare-exchange sequences.
const Reg x86_64::TMP_REGS[x86_64::NUM_TMP_REGS] = {
    { "%r13d", "%r13", true },
    { "%r14d", "%r14", true },
    { "%r15d", "%r15", true },
};

// Batcher's odd-even mergesort network.
// Reference: http://www.inf.fh-flensburg.de/lang/algorithmen/sortieren/networks/oemen.htm
class odd_even_mergesort {
public:
    // Calls out(i, j, 1) with i < j for every comparator of the n-element
    // network, in an order in which they can be applied. n must be a power of two.
    template <typename Out>
    void operator()(std::size_t n, Out out) const {
        assert(is_power_of_two(n));
        mergesort(0, n, out);
    }

    // Symbol name of the generated routine, e.g. "odd_even_sort8".
    std::string operator()(std::size_t n) const {
        std::ostringstream ss;
        ss << "odd_even_sort" << n;
        return ss.str();
    }

private:
    // Sorts the n elements starting at index lo.
    template <typename Out>
    static void mergesort(std::size_t lo, std::size_t n, Out& out) {
        if (n > 1) {
            mergesort(lo, n / 2, out);
            mergesort(lo + n / 2, n / 2, out);
            merge(lo, n, 1, out);
        }
    }

    // Merges the n elements lo, lo + skip, lo + 2*skip, ..., whose two halves
    // are already sorted: merge the even and odd subsequences recursively,
    // then fix up neighbours (1,2), (3,4), ..., (n-3, n-2).
    template <typename Out>
    static void merge(std::size_t lo, std::size_t n, std::size_t skip, Out& out) {
        if (n > 2) {
            merge(lo, n / 2, skip * 2, out);
            merge(lo + skip, n / 2, skip * 2, out);
            for (std::size_t i = 1; i + 2 < n; i += 2) {
                out(lo + i * skip, lo + (i + 1) * skip, 1);
            }
        } else {
            out(lo, lo + skip, 1);
        }
    }
};

// Bitonic sorting network (extra).
// Reference: http://jyoken.net/2005/kenpatsu/enari_oraf/
class bitonic_sort {
public:
    // Calls out(i, j, direction) with i < j for every comparator of the
    // n-element network. n must be a power of two.
    template <typename Out>
    void operator()(std::size_t n, Out out) const {
        assert(is_power_of_two(n));
        for (std::size_t bs = 2; bs <= n; bs *= 2) {
            // Neighbouring blocks are sorted in opposite directions so that each
            // pair forms a bitonic sequence for the next stage; the final stage
            // (bs == n) has a single ascending block.
            int direction = 1;
            for (std::size_t block = 0; block < n; block += bs) {
                for (std::size_t interval = bs / 2; interval >= 1; interval /= 2) {
                    // Compare k with k + interval only in the lower half of each
                    // 2*interval group; pairs straddling two groups are redundant.
                    for (std::size_t group = 0; group < bs; group += 2 * interval) {
                        for (std::size_t k = 0; k < interval; ++k) {
                            const std::size_t i = block + group + k;
                            out(i, i + interval, direction);
                        }
                    }
                }
                direction = -direction;
            }
        }
    }

    // Symbol name of the generated routine, e.g. "bitonic_sort8".
    std::string operator()(std::size_t n) const {
        std::ostringstream ss;
        ss << "bitonic_sort" << n;
        return ss.str();
    }
};

int main(int argc, char** argv) {
    long type = -1;
    long requested = 0;
    const bool valid = argc > 2
        && parse_long(argv[1], &type)
        && parse_long(argv[2], &requested)
        && (type == 0 || type == 1)
        && requested >= 4
        && is_power_of_two(static_cast<std::size_t>(requested));
    if (!valid) {
        // stderr, so the message does not end up in a redirected .s file.
        std::cerr << "usage: " << (argc > 0 ? argv[0] : "gen_sort") << " [0|1] N" << std::endl
                  << "0: odd-even sorting, 1: bitonic sorting" << std::endl
                  << "N: 4, 8, 16, 32, 64, ..." << std::endl;
        return 1;
    }

    const std::size_t n = static_cast<std::size_t>(requested);
    x86_64 x(std::cout);
    std::size_t count = 0;
    if (type == 0) {
        count = x.out_asm(n, odd_even_mergesort());
    } else {
        count = x.out_asm(n, bitonic_sort());
    }
    std::cout << "# " << n << " elements, " << count << " comparators" << std::endl;
    return 0;
}