Created
July 31, 2017 18:34
-
-
Save egordm/42b59b37ad3a8d2a229acbbe0f0a2479 to your computer and use it in GitHub Desktop.
Python osu!.db read and filter
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from struct import unpack_from | |
class DatabaseReader:
    """Sequential decoder over the raw bytes of an osu!-style database file.

    The complete file content is loaded into memory on construction. Every
    ``read_*`` method decodes one value at ``cursor`` and advances ``cursor``
    past it; ``skip*`` methods advance without decoding.

    All multi-byte values are decoded as little-endian with standard sizes, so
    the result does not depend on the byte order of the host machine.
    """

    # struct format characters keyed by the byte width of the value.
    _INT_FORMATS = {1: 'B', 2: 'H', 4: 'I', 8: 'Q'}
    _FLOAT_FORMATS = {4: 'f', 8: 'd'}

    # Milliseconds between 0001-01-01 and 1970-01-01 in the proleptic
    # Gregorian calendar (719162 days * 86400 s * 1000).
    _UNIX_EPOCH_OFFSET_MS = 62135596800000

    def __init__(self, file):
        """Read all bytes from *file* (a binary file object) and start at offset 0."""
        self.cursor = 0
        self._db = file.read()

    def read_num(self, length):
        """Read an unsigned integer that is *length* bytes wide (1, 2, 4 or 8).

        Raises ``KeyError`` for any other width.
        """
        return self.__read_b(self._INT_FORMATS[length], length)

    def read_date(self):
        """Read an 8-byte tick count and return it as Unix time in milliseconds.

        Ticks are treated as 100-nanosecond intervals counted from 0001-01-01,
        so dividing by 10000 yields milliseconds since that date. The result is
        a float because true division is used.
        """
        ticks = self.__read_b('Q', 8)
        return (ticks / 10000) - self._UNIX_EPOCH_OFFSET_MS

    def read_float(self, length):
        """Read a floating point value that is *length* bytes wide (4 or 8).

        Raises ``KeyError`` for any other width.
        """
        return self.__read_b(self._FLOAT_FORMATS[length], length)

    def read_bool(self):
        """Read one byte and return ``True`` unless it is zero."""
        return self.__read_b('b', 1) != 0x00

    def read_string(self):
        """Read a length-prefixed UTF-8 string.

        The string starts with a marker byte; zero means "no string" and yields
        ``''``. Otherwise a LEB128-encoded byte length follows, then the bytes.
        Bytes that are not valid UTF-8 are reported on stdout and yield ``''``;
        the cursor still ends up past the string.
        """
        not_empty = self.__read_b('b', 1)
        if not_empty == 0x00:
            return ''
        length = self.__decode_leb128()
        raw = self.__read_b(str(length) + 's', length)
        try:
            return raw.decode('utf-8')
        except UnicodeDecodeError:
            print("Invalid UTF-8 string. Returning empty string.")
            return ''

    def skip(self, jump):
        """Advance the cursor by *jump* bytes without decoding anything."""
        self.cursor += jump

    def skip_string(self):
        """Advance past one length-prefixed string without decoding its bytes."""
        not_empty = self.__read_b('b', 1)
        if not_empty == 0x00:
            return
        length = self.__decode_leb128()
        self.skip(length)

    def __read_b(self, val_type, length):
        """Decode one struct value of format *val_type* and advance by *length* bytes."""
        # '<' pins little-endian byte order and standard sizes; without it
        # struct falls back to the host's native byte order.
        value = unpack_from('<' + val_type, self._db, self.cursor)[0]
        self.cursor += length
        return value

    def __decode_leb128(self):
        """Decode an unsigned LEB128 integer (7 payload bits per byte, low bits first)."""
        ret = shift = 0
        while True:
            byte = self.__read_b('B', 1)
            ret |= ((byte & 0x7F) << shift)
            # A clear high bit marks the final byte of the number.
            if (byte & (1 << 7)) == 0:
                break
            shift += 7
        return ret
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import operator | |
import re | |
# Operator tokens accepted in a filter expression, mapped to the callable that
# implements them. Each callable is invoked as op(field_value, filter_value).
OPERATORS = {
    '~': operator.contains,
    '=': operator.eq,
    '>': operator.gt,
    '<': operator.lt,
    '>=': operator.ge,
    '<=': operator.le,
}

# Longest tokens first, so the alternation below tries '>=' before '>'.
OP_KEYS = sorted(OPERATORS, key=len, reverse=True)

# Single capture group: re.split() keeps the matched token in its result.
OPERATOR_REGEX = r'''({0}|\*|\n)'''.format('|'.join(OP_KEYS))
def filter_data(data: list, filters) -> list:
    """Return the items of *data* accepted by every filter in *filters*.

    Each filter must expose ``check(item)``. Filters are evaluated in order and
    evaluation stops at the first one that rejects the item. The order of
    *data* is preserved.
    """
    return [
        candidate
        for candidate in data
        if all(rule.check(candidate) for rule in filters)
    ]
class Filter:
    """One ``field <operator> value`` condition applied to an object's attribute.

    :param field: name of the attribute to test on each checked object.
    :param operator: two-argument callable invoked as
        ``operator(attribute_value, filter_value)``.
    :param value: filter value. Text that parses as a finite number is stored
        as ``int`` or ``float``; any other text is stored with single and
        double quote characters removed.
    """

    def __init__(self, field, operator, value):
        self.field = field
        self.operator = operator
        self.value = self._coerce(value)

    @staticmethod
    def _coerce(value):
        """Turn the raw filter text into an int, a float or an unquoted string."""
        if not isinstance(value, str):
            # Already a concrete value (e.g. built programmatically); keep it.
            return value
        # Parse numbers explicitly rather than with eval(): eval() rejects
        # inputs such as '007' and would execute arbitrary expressions.
        for convert in (int, float):
            try:
                number = convert(value)
            except ValueError:
                continue
            # float() also accepts 'nan' / 'inf' / 'infinity'; those are far
            # more likely to be search words than numbers, so keep them text.
            if number == number and number not in (float('inf'), float('-inf')):
                return number
            break
        return value.replace('\'', '').replace('"', '')

    def check(self, data):
        """Return the operator's result for ``getattr(data, field)`` and the value."""
        return self.operator(getattr(data, self.field), self.value)

    def __str__(self) -> str:
        # Recover the textual token by reverse lookup; '?' if the callable is
        # not one of the registered operators.
        op_str = '?'
        for symbol, func in OPERATORS.items():
            if func == self.operator:
                op_str = symbol
        return 'Filter: {0}{1}{2}'.format(self.field, op_str, self.value)
def is_real(txt):
    """Tell whether ``float(txt)`` succeeds.

    Only ``ValueError`` is treated as "not a number"; any other exception
    raised by ``float`` propagates to the caller.
    """
    try:
        float(txt)
    except ValueError:
        return False
    else:
        return True
def parse_filters(text):
    """Parse a space-separated filter expression into a list of ``Filter`` objects.

    Each filter has the form ``<field><operator><value>``, where the operator
    is one of the keys of ``OPERATORS``. Spaces inside a single-quoted value do
    not separate filters. Pieces that do not split into exactly field, operator
    and value, or whose operator has no implementation, are ignored.

    :param text: the raw filter expression.
    :return: the parsed filters, in the order they appear in *text*.
    """
    # Split on a space only when an even number of single quotes follows it,
    # i.e. when the space is not inside a '...' quoted value.
    raw_filters = re.split(''' (?=(?:[^']|'[^']*')*$)''', text)
    ret = []
    for raw_filter in raw_filters:
        # The capture group in OPERATOR_REGEX keeps the operator token in the
        # result; empty pieces are dropped.
        parts = [part for part in re.split(OPERATOR_REGEX, raw_filter) if part]
        if len(parts) != 3:
            continue
        field, symbol, value = parts
        # OPERATOR_REGEX also matches '*' and newline, which have no entry in
        # OPERATORS; treat those as malformed instead of raising KeyError.
        op_func = OPERATORS.get(symbol)
        if op_func is None:
            continue
        ret.append(Filter(field, op_func, value))
    return ret
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Loads the maps | |
from tools.database_reader import DatabaseReader | |
from tools.osu_models import Beatmap, TimingPoint | |
from tools.filter_tools import filter_data, parse_filters | |
class OsuDB:
    """In-memory view of the beatmaps stored in a database file.

    Constructing an instance immediately reads the file at *file_path*.
    """

    # Class-level defaults; every instance gets its own values in read().
    version = 0
    user = 'Unknown'
    beatmaps = []  # placeholder only, never mutated: instances own a fresh list

    def __init__(self, file_path):
        self.path = file_path
        # Per-instance list so that separate databases never share entries.
        self.beatmaps = []
        self.read()

    def read(self):
        """(Re)load the header and all beatmap entries from ``self.path``.

        Entries for which ``read_beatmap`` returns ``None`` are left out.
        Progress is reported on stdout.
        """
        # Start from an empty list so calling read() again does not duplicate
        # the previously loaded entries.
        self.beatmaps = []
        with open(self.path, 'rb') as file:
            reader = DatabaseReader(file)
            self.version = reader.read_num(4)
            # NOTE(review): 13 header bytes between the version and the user
            # name are not used by this loader; confirm their meaning against
            # the database format before changing the count.
            reader.skip(13)
            self.user = reader.read_string()
            num_beatmaps = reader.read_num(4)
            print('Reading {}\'s database. Expecting {} maps.'.format(self.user, num_beatmaps))
            for _ in range(num_beatmaps):
                bm = read_beatmap(reader, self.version)
                if bm is not None:
                    self.beatmaps.append(bm)
            print('Loaded {}/{} maps.'.format(len(self.beatmaps), num_beatmaps))

    def filter(self, fs):
        """Return the loaded beatmaps matching the filter expression *fs*.

        The parsed filters are printed to stdout before being applied.
        """
        filters = parse_filters(fs)
        print([str(parsed) for parsed in filters])
        return filter_data(self.beatmaps, filters)
def read_beatmap(reader: DatabaseReader, version):
    """Read one beatmap entry from *reader* and return it as a ``Beatmap``.

    The entry begins with a 4-byte size. The cursor position that size implies
    for the end of the entry is remembered, so a failed parse can continue at
    the next entry.

    :param reader: reader positioned at the start of a beatmap entry.
    :param version: database version; the per-mode difficulty tables are only
        read when it is at least 20140609.
    :return: the populated ``Beatmap``, or ``None`` when any exception is
        raised while parsing. In that case the problem is printed and the
        cursor is moved to the expected end of the entry.

    The initial size read happens outside the ``try`` block, so an error there
    propagates to the caller.
    """
    # read_num() advances the cursor before reader.cursor is evaluated, so this
    # is "declared size + position just past the size field".
    plan_b = reader.read_num(4) + reader.cursor
    try:
        ret = Beatmap()
        # NOTE(review): every skip_string()/skip(n) below steps over a field
        # this loader does not keep. The counts must match the on-disk entry
        # layout exactly; confirm against the format documentation before
        # changing any of them.
        ret.artist = reader.read_string()
        reader.skip_string()
        ret.title = reader.read_string()
        reader.skip_string()
        ret.creator = reader.read_string()
        ret.version = reader.read_string()
        ret.audio_file = reader.read_string()
        reader.skip_string()
        ret.osu_file = reader.read_string()
        ret.ranked = reader.read_num(1)
        reader.skip(14)
        ret.ar = reader.read_float(4)
        ret.cs = reader.read_float(4)
        ret.hp = reader.read_float(4)
        ret.od = reader.read_float(4)
        reader.skip(8)
        # diffs
        # For older versions the four *_diffs attributes are left unset.
        if version >= 20140609:
            ret.std_diffs = read_diff_pairs(reader)
            ret.taiko_diffs = read_diff_pairs(reader)
            ret.ctb_diffs = read_diff_pairs(reader)
            ret.mania_diffs = read_diff_pairs(reader)
        ret.time_drain = reader.read_num(4)
        ret.time_total = reader.read_num(4)
        reader.skip(4)
        # Timing Points
        # A 4-byte count followed by that many timing point records.
        n_tps = reader.read_num(4)
        ret.timingpoints = [read_timing_point(reader) for _ in range(n_tps)]
        ret.beatmap_id = reader.read_num(4)
        ret.set_id = reader.read_num(4)
        reader.skip(14)
        ret.mode = reader.read_num(1)
        reader.skip_string()
        reader.skip_string()
        reader.skip(2)
        reader.skip_string()
        reader.skip(10)
        ret.folder_name = reader.read_string()
        reader.skip(18)
        # Sanity check: the bytes consumed must equal the declared entry size.
        if reader.cursor != plan_b: raise Exception('Offsets are not equal. Entry corrupted?')
        return ret
    except Exception as e:
        # Deliberate best-effort: report, jump to the expected end of the
        # entry, and let the caller continue with the next one.
        print('Ripperoni ' + str(e))
        reader.cursor = plan_b
        return None
def read_diff_pairs(reader: DatabaseReader):
    """Read a counted table of (4-byte integer, 8-byte float) pairs into a dict.

    The table starts with a 4-byte pair count. Each pair consists of one byte
    that is read and discarded, the 4-byte integer key, another discarded byte
    and the 8-byte float value. A key that occurs again overwrites the earlier
    value.

    :return: dict mapping each integer key to its float value.
    """
    pairs = {}
    remaining = reader.read_num(4)
    while remaining > 0:
        reader.read_num(1)  # byte preceding the key; value unused
        key = reader.read_num(4)
        reader.read_num(1)  # byte preceding the value; value unused
        pairs[key] = reader.read_float(8)
        remaining -= 1
    return pairs
def read_timing_point(reader: DatabaseReader):
    """Read one timing point record and return it as a ``TimingPoint``.

    The record is decoded in file order: an 8-byte float stored as ``mpb``, an
    8-byte float stored as ``offset`` and a boolean stored as ``inherited``.
    """
    mpb = reader.read_float(8)
    offset = reader.read_float(8)
    inherited = reader.read_bool()

    point = TimingPoint()
    point.mpb, point.offset, point.inherited = mpb, offset, inherited
    return point
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Beatmap(object):
    """Plain record for one beatmap entry.

    ``__slots__`` fixes the set of attributes. None of them is initialised
    here, so reading a slot that was never assigned raises ``AttributeError``.
    """

    __slots__ = ['artist', 'title', 'creator', 'version', 'audio_file', 'osu_file', 'folder_name', 'ranked',
                 'beatmap_id', 'set_id', 'ar', 'cs', 'hp', 'od', 'std_diffs', 'taiko_diffs', 'ctb_diffs',
                 'mania_diffs', 'time_drain', 'time_total', 'timingpoints', 'mode', 'loaded']

    @property
    def std_rating(self):
        """Return the entry stored under key/index 0 of ``std_diffs``, or -1.

        -1 is returned when ``std_diffs`` was never assigned, is empty, or has
        no entry 0.
        """
        diffs = getattr(self, 'std_diffs', None)
        if not diffs:
            return -1
        try:
            return diffs[0]
        except (KeyError, IndexError):
            # Non-empty table that simply lacks entry 0.
            return -1
class TimingPoint:
    """Plain record for one timing point.

    ``__slots__`` limits instances to the attributes ``offset``, ``mpb`` and
    ``inherited``; none of them is initialised here.
    """

    __slots__ = ['offset', 'mpb', 'inherited']
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Usage