Created
March 7, 2022 06:34
-
-
Save numb3r3/d80515f68f8d0503766959ea26af8a70 to your computer and use it in GitHub Desktop.
prepare-dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import re | |
from pathlib import Path | |
import click | |
from docarray import Document, DocumentArray | |
# Tokens that are never accepted as caption words (see is_valid_token).
BLACK_TOKENS = {'test', 'prefab', 'background'}
def strip_token(token):
    """Replace every run of digits in ``token`` with a single space and trim.

    Only leading and trailing whitespace is removed afterwards, so a digit
    run in the middle of the token leaves one interior space behind
    (``"a1b"`` becomes ``"a b"``).

    Args:
        token: The raw token string.

    Returns:
        The token with digit runs replaced by spaces and outer whitespace
        stripped; an empty string if nothing else remains.
    """
    # Raw string: "\d" in a plain literal is an invalid escape sequence.
    return re.sub(r"\d+", " ", token).strip()
def is_valid_token(token):
    """Decide whether ``token`` should be kept as a caption word.

    A token is rejected when it is empty, has two characters or fewer, is
    wrapped in square brackets, or is listed in ``BLACK_TOKENS``. Otherwise
    it is accepted only if more than 65% of its characters are lowercase
    letters.

    Args:
        token: The candidate token string.

    Returns:
        ``True`` if the token passes every check, ``False`` otherwise.
    """
    if not token or len(token) <= 2:
        return False

    bracketed = token.startswith('[') and token.endswith(']')
    if bracketed or token in BLACK_TOKENS:
        return False

    lowercase_letters = sum(1 for ch in token if ch.isalpha() and ch.islower())
    return lowercase_letters / len(token) > 0.65
def split_on_uppercase(s, seperators=('.', '-', '_', ' '), keep_contiguous: bool = False):
    """Split ``s`` at separator characters and at uppercase letters.

    Separator characters are dropped from the result. An uppercase letter
    (other than the very first character) starts a new part and is kept as
    the first character of that part.

    Args:
        s: The string to split.
        seperators: Characters that split the string and are discarded.
        keep_contiguous: When true, an uppercase letter only starts a new
            part if the character before it or the character after it is
            lowercase, so a run of capitals such as ``"HTTP"`` stays
            together.

    Returns:
        The non-empty parts, in order of appearance.
    """
    length = len(s)
    parts = []
    start = 0
    for i, ch in enumerate(s):
        if ch in seperators:
            # Close the current part and resume after the separator. Index 0
            # is handled too, so a leading separator never leaks into the
            # first part.
            parts.append(s[start:i])
            start = i + 1
        elif i > 0 and ch.isupper():
            if keep_contiguous:
                prev_is_lower = s[i - 1].islower()
                next_is_lower = i + 1 < length and s[i + 1].islower()
                if not (prev_is_lower or next_is_lower):
                    # Inside a run of capitals: do not split here.
                    continue
            parts.append(s[start:i])
            start = i
    parts.append(s[start:])
    # Adjacent split points produce empty slices; drop them.
    return [part for part in parts if part]
def get_tokens(s):
    """Extract the usable tokens from the raw string ``s``.

    The string is split with ``split_on_uppercase`` (keeping runs of
    capitals together), each piece is cleaned with ``strip_token``, and
    pieces rejected by ``is_valid_token`` are discarded.

    Args:
        s: The raw string to tokenize.

    Returns:
        The list of surviving tokens, or ``None`` when two or fewer tokens
        survive and none of them is at least three characters long.
    """
    pieces = split_on_uppercase(s, keep_contiguous=True)
    cleaned = (strip_token(piece) for piece in pieces)
    kept = [tok for tok in cleaned if is_valid_token(tok)]

    if len(kept) > 2:
        return kept

    longest = max((len(tok) for tok in kept), default=0)
    return None if longest < 3 else kept
def norm_tokens(tokens):
    """Lowercase every token and join them into one space-separated string.

    Args:
        tokens: An iterable of token strings.

    Returns:
        The lowercased tokens joined by single spaces.
    """
    return ' '.join(tok.lower() for tok in tokens)
@click.command()
@click.option('-i', '--input_path', help='the input JSON-data path')
@click.option('-o', '--output_path', help='the output da path')
def main(input_path, output_path):
    """Build a DocumentArray from JSON metadata files and save it.

    Every ``*.json`` file found recursively under ``input_path`` is read.
    A caption is derived from the file name of the package path; files
    whose name yields no usable tokens are skipped. Each remaining entry
    becomes a Document with one chunk per featured image, and the
    collection is written to ``output_path`` with ``save_binary``.
    """
    da = DocumentArray()
    for fn in Path(input_path).glob('**/*.json'):
        # Use a context manager so each metadata file is closed promptly
        # instead of leaking one open handle per file.
        with fn.open() as fp:
            data = json.load(fp)

        category = data['category']['slug']
        package_path = data['extendedProperties']['packagePath']
        images = data['images']['default']['featured']

        # File name of the package without its last extension; any remaining
        # dots become spaces.
        package_file = ' '.join(package_path.split('/')[-1].split('.')[:-1])
        tokens = get_tokens(package_file)
        if not tokens:
            # No caption can be derived from this name; skip the entry.
            continue

        caption = norm_tokens(tokens)
        doc = Document(
            uri=package_path,
            tags={
                'name': data['name'],
                'category': category,
                'caption': caption,
                **data['extendedProperties'],
            },
        )
        for img in images:
            img_doc = Document(uri=img['href'])
            # NOTE(review): presumably fetches the image at the URI and
            # resizes it to 256x256 - confirm against the docarray version in use.
            img_doc.load_uri_to_image_tensor().set_image_tensor_shape((256, 256))
            doc.chunks.append(img_doc)
        da.append(doc)
    da.save_binary(output_path)
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Run