@JiaxiangZheng
Created March 4, 2025 16:25
bpe implementation
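// Single-file Byte Pair Encoding (BPE) trainer: it reads a plain-text corpus,
// learns merge rules until the vocabulary reaches a target size, and prints
// the vocabulary plus a few tokenized words. Pair counting and pair merging
// are parallelized across batches of the corpus with std::async.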
#include <algorithm>
#include <chrono>
#include <fstream>
#include <future>
#include <iostream>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
using namespace std;
using Word = vector<string>;
using Corpus = unordered_map<string, int>;
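// A Word is a sequence of tokens (initially single characters); a Corpus maps
// each raw word (with "</w>" appended) to its frequency in the input text.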
#if 0
// Single-threaded reference version, kept for comparison.
void get_pairs(const vector<Word> &corpus,
               std::unordered_map<string, int> &pairs) {
  for (const auto &word : corpus) {
    if (word.size() < 2) // guard against size_t underflow below
      continue;
    for (size_t i = 0; i < word.size() - 1; ++i) {
      string pair = word[i] + " " + word[i + 1];
      pairs[pair]++;
    }
  }
}
#else
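// Count adjacent-token pairs for corpus[start, end) into `pairs`. Pairs are
// keyed as "first second" (space-separated). Counts go through a local map
// first to reduce rehashing of the shared map.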
void get_pairs_single(const std::vector<Word> &corpus,
                      std::unordered_map<std::string, int> &pairs,
                      size_t start, size_t end) {
  std::unordered_map<std::string, int> local_pairs;
  local_pairs.reserve(1000);
  for (size_t i = start; i < end && i < corpus.size(); ++i) {
    const auto &word = corpus[i];
    if (word.size() < 2)
      continue;
    for (size_t j = 0; j < word.size() - 1; ++j) {
      std::string pair = word[j];
      pair.reserve(pair.size() + word[j + 1].size() + 1);
      pair += " ";
      pair += word[j + 1];
      local_pairs[pair]++;
    }
  }
  for (const auto &[pair, count] : local_pairs) {
    pairs[pair] += count;
  }
}
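// Count all adjacent-token pairs in the corpus. Falls back to the
// single-threaded path for small corpora; otherwise splits the corpus into
// one batch per hardware thread and merges the per-thread counts under a
// mutex.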
void get_pairs(const std::vector<Word> &corpus,
               std::unordered_map<std::string, int> &pairs) {
  pairs.clear();
  pairs.reserve(corpus.size() / 2);
  size_t num_threads = std::thread::hardware_concurrency();
  if (num_threads <= 1 || corpus.size() < 1000) {
    get_pairs_single(corpus, pairs, 0, corpus.size());
    return;
  }
  size_t batch_size = (corpus.size() + num_threads - 1) / num_threads;
  std::vector<std::future<void>> futures;
  std::mutex merge_mutex;
  for (size_t t = 0; t < num_threads; ++t) {
    size_t start = t * batch_size;
    size_t end = std::min(start + batch_size, corpus.size());
    if (start >= corpus.size())
      break;
    futures.push_back(std::async(
        std::launch::async, [&corpus, &pairs, &merge_mutex, start, end]() {
          std::unordered_map<std::string, int> local_pairs;
          get_pairs_single(corpus, local_pairs, start, end);
          std::lock_guard<std::mutex> lock(merge_mutex);
          for (const auto &[pair, count] : local_pairs) {
            pairs[pair] += count;
          }
        }));
  }
  for (auto &f : futures) {
    f.get();
  }
}
#endif
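// Replace every non-overlapping occurrence of the pair ("first second") in
// `word` with the concatenated token "firstsecond", scanning left to right.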
Word merge_pair(const Word &word, const std::string &pair) {
  if (word.size() < 2)
    return word;
  size_t space_pos = pair.find(' ');
  std::string pair_first = pair.substr(0, space_pos);
  std::string pair_second = pair.substr(space_pos + 1);
  std::string merged = pair_first + pair_second;
  Word new_word;
  new_word.reserve(word.size());
  for (size_t i = 0; i < word.size();) {
    if (i + 1 < word.size() && word[i] == pair_first &&
        word[i + 1] == pair_second) {
      new_word.push_back(merged);
      i += 2;
    } else {
      new_word.push_back(word[i]);
      i++;
    }
  }
  return new_word;
}
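// Apply merge_pair to every word in the corpus, splitting the work across
// threads. Each task owns a disjoint slice of `corpus`, so no locking is
// needed here.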
void merge_pairs_batch(
    std::vector<Word> &corpus, const std::string &pair,
    size_t num_threads = std::thread::hardware_concurrency()) {
  if (corpus.empty())
    return;
  num_threads = std::min(num_threads, corpus.size());
  if (num_threads <= 1) {
    for (auto &word : corpus) {
      word = merge_pair(word, pair);
    }
    return;
  }
  size_t batch_size = (corpus.size() + num_threads - 1) / num_threads;
  std::vector<std::future<void>> futures;
  // Launch asynchronous tasks, one per batch of words.
  for (size_t t = 0; t < num_threads; ++t) {
    size_t start = t * batch_size;
    size_t end = std::min(start + batch_size, corpus.size());
    if (start >= corpus.size())
      break;
    futures.push_back(
        std::async(std::launch::async, [&corpus, &pair, start, end]() {
          for (size_t i = start; i < end; ++i) {
            corpus[i] = merge_pair(corpus[i], pair);
          }
        }));
  }
  for (auto &f : futures) {
    f.get();
  }
}
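// Train BPE on the given file: build a character-level corpus (with an
// explicit "</w>" end-of-word marker), then repeatedly merge the most
// frequent adjacent pair until the vocabulary reaches `vocab_size`. Returns
// the vocabulary and a word -> token-sequence map.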
pair<vector<string>, unordered_map<string, Word>>
train_bpe(const string &filename, int vocab_size) {
  ifstream file(filename);
  if (!file) {
    cerr << "unable to open file: " << filename << endl;
    return {};
  }
  Corpus raw_corpus;
  string line, word;
  while (getline(file, line)) {
    istringstream iss(line);
    while (iss >> word) {
      raw_corpus[word + "</w>"]++;
    }
  }
  file.close();
  // Split each word into single-character tokens, keeping the trailing
  // "</w>" end-of-word marker as a single token so it matches the seed vocab.
  vector<Word> corpus;
  for (const auto &[word, freq] : raw_corpus) {
    Word tokens;
    for (size_t i = 0; i + 4 < word.size(); ++i) {
      tokens.push_back(string(1, word[i]));
    }
    tokens.push_back("</w>");
    for (int i = 0; i < freq; ++i) {
      corpus.push_back(tokens);
    }
  }
  vector<string> vocab = {"</w>"};
  for (char c = 'a'; c <= 'z'; ++c)
    vocab.push_back(string(1, c));
  for (char c = 'A'; c <= 'Z'; ++c)
    vocab.push_back(string(1, c));
  // BPE iterations: repeatedly merge the most frequent adjacent pair.
  std::unordered_map<string, int> pairs;
  pairs.reserve(5000);
  vector<string> merges; // merge rules ("first second"), in training order
  std::cout << "corpus size: " << corpus.size() << std::endl;
  while (vocab.size() < static_cast<size_t>(vocab_size)) {
    auto start = std::chrono::high_resolution_clock::now();
    pairs.clear();
    get_pairs(corpus, pairs);
    auto end = std::chrono::high_resolution_clock::now();
    auto count_ms =
        std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
    if (pairs.empty())
      break;
    // Find the most frequent pair.
    auto max_pair = max_element(
        pairs.begin(), pairs.end(),
        [](const auto &a, const auto &b) { return a.second < b.second; });
    string best_pair = max_pair->first;
    merges.push_back(best_pair);
    string merged = best_pair;
    merged.erase(remove(merged.begin(), merged.end(), ' '), merged.end());
    vocab.push_back(merged);
    start = std::chrono::high_resolution_clock::now();
    merge_pairs_batch(corpus, best_pair);
    end = std::chrono::high_resolution_clock::now();
    auto merge_ms =
        std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
    cout << "best_pair: " << best_pair << " " << max_pair->second << "\t"
         << "vocab size: " << vocab.size() << "\t"
         << "count: " << count_ms.count() << "ms, merge: " << merge_ms.count()
         << "ms" << endl;
  }
  // Tokenize each distinct word by replaying the learned merges in order.
  unordered_map<string, Word> token_map;
  for (const auto &[word, freq] : raw_corpus) {
    Word tokens;
    for (size_t i = 0; i + 4 < word.size(); ++i) {
      tokens.push_back(string(1, word[i]));
    }
    tokens.push_back("</w>");
    for (const auto &rule : merges) {
      tokens = merge_pair(tokens, rule);
    }
    token_map[word] = tokens;
  }
  return {vocab, token_map};
}
int main(int argc, char *argv[]) {
  // Default corpus path; pass a different file as the first argument.
  string filename =
      argc > 1 ? argv[1] : "C:/Users/jiaxi/Downloads/t8.shakespeare.txt";
  int vocab_size = 5000;
  auto [vocab, token_map] = train_bpe(filename, vocab_size);
  cout << "Vocab Size (" << vocab.size() << ")" << endl;
  for (const auto &token : vocab) {
    cout << token << endl;
  }
  cout << "\nTokenized Words:" << endl;
  int count = 0;
  for (const auto &[word, tokens] : token_map) {
    cout << word << " -> ";
    for (const auto &t : tokens) {
      cout << t << " ";
    }
    cout << endl;
    if (++count >= 5)
      break;
  }
  return 0;
}
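// Build and run (a minimal sketch, assuming the file is saved as bpe.cpp and
// g++ is available; any C++17 compiler with thread support should work):
//   g++ -std=c++17 -O2 -pthread bpe.cpp -o bpe
//   ./bpe path/to/corpus.txt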