| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <algorithm>
#include <chrono>
#include <cstdint>
#include <exception>
#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <vector>
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
/// One node of the temporary pointer-based trie built from the vocabulary.
///
/// A node owns its children: destroying a node destroys its whole subtree.
struct TrieNode {
    /// Token id of the key that ends at this node, or -1 when no key ends here.
    int token_id = -1;

    /// Owned child nodes, keyed (and therefore ordered) by the next key byte.
    std::map<unsigned char, TrieNode*> children;

    TrieNode() = default;

    // The children are raw owning pointers, so a member-wise copy would leave
    // two nodes that both delete the same subtree. Forbid copying outright
    // (this also suppresses the implicit move operations).
    TrieNode(const TrieNode&) = delete;
    TrieNode& operator=(const TrieNode&) = delete;

    /// Recursively frees every child subtree.
    ~TrieNode() {
        for (auto& entry : children) {
            delete entry.second;
        }
    }
};
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| class DATCompiler { |
| private: |
| std::vector<int32_t> base; |
| std::vector<int32_t> check; |
| std::vector<int32_t> values; |
| |
| int32_t max_size = 0; |
| int32_t nodes_used = 0; |
| |
| |
| size_t vocab_size = 0; |
| double build_time_ms = 0.0; |
|
|
| public: |
| DATCompiler() { |
| |
| resize(500000); |
| base[0] = 1; |
| } |
| |
| |
| |
| |
| void resize(int32_t new_size) { |
| if (new_size <= max_size) return; |
| |
| base.resize(new_size, 0); |
| check.resize(new_size, -1); |
| values.resize(new_size, -1); |
| max_size = new_size; |
| } |
| |
| |
| |
| |
| void insert_trie(TrieNode* root, const std::string& key, int token_id) { |
| TrieNode* current = root; |
| |
| for (unsigned char byte : key) { |
| auto it = current->children.find(byte); |
| if (it == current->children.end()) { |
| current->children[byte] = new TrieNode(); |
| } |
| current = current->children[byte]; |
| } |
| |
| current->token_id = token_id; |
| } |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| int32_t find_base(const std::vector<unsigned char>& children) { |
| int32_t b = 1; |
| |
| while (true) { |
| bool collision = false; |
| |
| |
| for (unsigned char c : children) { |
| int32_t idx = b + static_cast<int32_t>(c); |
| |
| |
| if (idx >= max_size) { |
| resize(idx + 512); |
| } |
| |
| |
| if (check[idx] != -1) { |
| collision = true; |
| break; |
| } |
| } |
| |
| if (!collision) { |
| return b; |
| } |
| |
| b++; |
| } |
| } |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| void build_dat(TrieNode* node, int32_t dat_index) { |
| if (node->children.empty()) { |
| return; |
| } |
| |
| |
| std::vector<unsigned char> chars; |
| chars.reserve(node->children.size()); |
| |
| for (const auto& [byte, child] : node->children) { |
| chars.push_back(byte); |
| } |
| |
| |
| int32_t b = find_base(chars); |
| base[dat_index] = b; |
| |
| |
| for (unsigned char c : chars) { |
| int32_t child_idx = b + static_cast<int32_t>(c); |
| |
| |
| check[child_idx] = dat_index; |
| nodes_used = std::max(nodes_used, child_idx + 1); |
| |
| |
| TrieNode* child_node = node->children[c]; |
| if (child_node->token_id != -1) { |
| values[child_idx] = child_node->token_id; |
| } |
| } |
| |
| |
| for (unsigned char c : chars) { |
| int32_t child_idx = b + static_cast<int32_t>(c); |
| build_dat(node->children[c], child_idx); |
| } |
| } |
| |
| |
| |
| |
| void save(const std::string& filename) { |
| |
| int32_t real_size = nodes_used; |
| while (real_size > 0 && check[real_size - 1] == -1) { |
| real_size--; |
| } |
| real_size++; |
| |
| std::ofstream out(filename, std::ios::binary); |
| if (!out.is_open()) { |
| std::cerr << "[C++ Compiler] ERROR: Cannot open file: " << filename << std::endl; |
| return; |
| } |
| |
| |
| uint32_t magic = 0x59415243; |
| uint32_t version = 2; |
| |
| out.write(reinterpret_cast<char*>(&magic), 4); |
| out.write(reinterpret_cast<char*>(&version), 4); |
| out.write(reinterpret_cast<char*>(&real_size), 4); |
| |
| |
| out.write(reinterpret_cast<char*>(base.data()), real_size * sizeof(int32_t)); |
| out.write(reinterpret_cast<char*>(check.data()), real_size * sizeof(int32_t)); |
| out.write(reinterpret_cast<char*>(values.data()), real_size * sizeof(int32_t)); |
| |
| out.close(); |
| |
| std::cout << " [C++ Compiler] Saved DAT: " << real_size << " nodes, " |
| << (real_size * 12 / 1024) << " KB" << std::endl; |
| } |
| |
| |
| |
| |
| |
| |
| |
| void compile(const std::vector<std::string>& vocab, const std::string& out_file) { |
| auto start_time = std::chrono::high_resolution_clock::now(); |
| |
| vocab_size = vocab.size(); |
| std::cout << " [C++ Compiler] Building trie from " << vocab_size << " tokens..." << std::endl; |
| |
| |
| TrieNode* root = new TrieNode(); |
| |
| for (size_t i = 0; i < vocab.size(); ++i) { |
| insert_trie(root, vocab[i], static_cast<int>(i)); |
| } |
| |
| |
| |
| check[0] = -1; |
| nodes_used = 1; |
| |
| std::cout << " [C++ Compiler] Converting trie to DAT..." << std::endl; |
| build_dat(root, 0); |
| |
| |
| save(out_file); |
| |
| |
| delete root; |
| |
| auto end_time = std::chrono::high_resolution_clock::now(); |
| build_time_ms = std::chrono::duration<double, std::milli>(end_time - start_time).count(); |
| |
| std::cout << " [C++ Compiler] Complete in " << build_time_ms << " ms" << std::endl; |
| } |
| |
| |
| int32_t get_node_count() const { return nodes_used; } |
| double get_build_time_ms() const { return build_time_ms; } |
| }; |
|
|
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| static PyObject* compile_dat(PyObject* self, PyObject* args) { |
| PyObject* vocab_list; |
| const char* out_path; |
| |
| if (!PyArg_ParseTuple(args, "Os", &vocab_list, &out_path)) { |
| return NULL; |
| } |
| |
| |
| if (!PyList_Check(vocab_list)) { |
| PyErr_SetString(PyExc_TypeError, "vocab must be a list"); |
| return NULL; |
| } |
| |
| |
| Py_ssize_t len = PyList_Size(vocab_list); |
| std::vector<std::string> vocab; |
| vocab.reserve(len); |
| |
| for (Py_ssize_t i = 0; i < len; ++i) { |
| PyObject* item = PyList_GetItem(vocab_list, i); |
| |
| if (!PyUnicode_Check(item)) { |
| |
| continue; |
| } |
| |
| |
| const char* str = PyUnicode_AsUTF8(item); |
| if (str) { |
| vocab.push_back(std::string(str)); |
| } |
| } |
| |
| |
| Py_BEGIN_ALLOW_THREADS |
| |
| |
| DATCompiler compiler; |
| compiler.compile(vocab, std::string(out_path)); |
| |
| Py_END_ALLOW_THREADS |
| |
| |
| PyObject* result = PyDict_New(); |
| PyDict_SetItemString(result, "vocab_size", PyLong_FromLong(static_cast<long>(vocab.size()))); |
| PyDict_SetItemString(result, "node_count", PyLong_FromLong(0)); |
| PyDict_SetItemString(result, "output_path", PyUnicode_FromString(out_path)); |
| |
| return result; |
| } |
|
|
|
|
| |
| |
| |
| static PyObject* get_version(PyObject* self, PyObject* args) { |
| return PyUnicode_FromString("2.0.0-hyperfast"); |
| } |
|
|
|
|
| |
| |
| |
|
|
| static PyMethodDef CompilerMethods[] = { |
| { |
| "compile_dat", |
| compile_dat, |
| METH_VARARGS, |
| "Fast C++ DAT Compiler.\n\n" |
| "Args:\n" |
| " vocab (List[str]): Vocabulary strings in order\n" |
| " output_path (str): Path to write .dat file\n\n" |
| "Returns:\n" |
| " dict: Compilation statistics\n\n" |
| "Example:\n" |
| " >>> from crayon.c_ext import crayon_compiler\n" |
| " >>> crayon_compiler.compile_dat(['hello', 'world'], 'vocab.dat')\n" |
| }, |
| { |
| "get_version", |
| get_version, |
| METH_NOARGS, |
| "Get compiler version string." |
| }, |
| {NULL, NULL, 0, NULL} |
| }; |
|
|
| static struct PyModuleDef compiler_module = { |
| PyModuleDef_HEAD_INIT, |
| "crayon_compiler", |
| "CRAYON Fast DAT Compiler\n\n" |
| "Converts vocabulary lists to Double-Array Trie binaries.\n" |
| "~500x faster than Python implementation.\n\n" |
| "Author: XERV AI Research\n" |
| "Version: 2.0.0", |
| -1, |
| CompilerMethods |
| }; |
|
|
|
|
| PyMODINIT_FUNC PyInit_crayon_compiler(void) { |
| return PyModule_Create(&compiler_module); |
| } |
|
|