from pytorch_tools import torchfold

def encode_tree_fold(fold, tree):
    """Record the fold operations that encode ``tree`` and return its logits node.

    The tree is walked in post-order (left subtree, right subtree, then the
    node itself):

    * a leaf records ``fold.add('leaf', node.id)``;
    * an internal node records ``fold.add('children', left_h, left_c,
      right_h, right_c)`` from the two outputs of each child.

    Each recorded result is split into two outputs (presumably an LSTM-style
    ``(h, c)`` pair -- confirm against the model's ``leaf``/``children`` ops).
    Finally ``fold.add('logits', ...)`` is recorded on the root's first output.

    Every non-leaf node is expected to have both ``left`` and ``right``.

    Args:
        fold: Fold object that records operations through ``add``.
        tree: object whose ``root`` attribute is the top node of the tree.

    Returns:
        Whatever ``fold.add('logits', ...)`` returns for the root encoding.
    """
    # Explicit work stack instead of recursion, so tree depth is not bounded
    # by Python's recursion limit. The order of ``fold.add`` calls is the same
    # left-then-right post-order that a recursive walk produces.
    finished = []                # (h, c) pairs of fully encoded subtrees
    todo = [(tree.root, False)]  # (node, children_already_encoded)
    while todo:
        node, children_done = todo.pop()
        if children_done:
            # The right child was encoded last, so its pair is on top.
            right_h, right_c = finished.pop()
            left_h, left_c = finished.pop()
            h, c = fold.add('children', left_h, left_c, right_h, right_c).split(2)
            finished.append((h, c))
        elif node.is_leaf():
            h, c = fold.add('leaf', node.id).split(2)
            finished.append((h, c))
        else:
            # Push in reverse so the left subtree is popped (encoded) first,
            # and revisit this node once both children are done.
            todo.append((node, True))
            todo.append((node.right, False))
            todo.append((node.left, False))
    root_h, _ = finished.pop()
    return fold.add('logits', root_h)

# Setup elided in this excerpt; the code below expects `args`, `batch`,
# `model` and `criterion` to be defined there (they are not visible here).
...

# A single Fold collects the encoding operations of every tree in the batch.
fold = torchfold.Fold(cuda=args.cuda)

all_logits, all_labels = [], []
for tree in batch:
    # Previously called `encode_tree_folded`, which is not defined anywhere
    # and raised NameError; the encoder defined above is `encode_tree_fold`.
    all_logits.append(encode_tree_fold(fold, tree))
    all_labels.append(tree.label)

# Evaluate the recorded operations with `model`. Presumably res[0] / res[1]
# are the batched results for all_logits / all_labels, in the order passed
# (per torchfold's Fold.apply) -- confirm against the torchfold version used.
res = fold.apply(model, [all_logits, all_labels])
loss = criterion(res[0], res[1])