from pytorch_tools import torchfold


def encode_tree_fold(fold, tree):
    """Register the operations needed to encode ``tree`` on ``fold``.

    The tree is walked recursively from ``tree.root``. Each leaf adds a
    ``'leaf'`` operation (argument: ``node.id``); each internal node adds a
    ``'children'`` operation whose arguments are the two halves produced for
    its left and right sub-trees. Finally a ``'logits'`` operation is added
    on the first half of the root's result.

    Args:
        fold: Object exposing ``add(op_name, *args)`` whose result supports
            ``split(2)`` (here a ``torchfold.Fold``).
        tree: Object with a ``root`` node. Nodes must provide ``is_leaf()``,
            ``id`` (leaves) and ``left`` / ``right`` (internal nodes).

    Returns:
        Whatever ``fold.add('logits', ...)`` returns for the root encoding.
    """

    def encode_node(node):
        # Each node yields a pair obtained via ``.split(2)``.
        # NOTE(review): the ``_h`` / ``_c`` names suggest LSTM hidden and
        # cell states -- confirm against the model's 'leaf'/'children' ops.
        if node.is_leaf():
            return fold.add('leaf', node.id).split(2)

        left_h, left_c = encode_node(node.left)
        right_h, right_c = encode_node(node.right)
        return fold.add(
            'children', left_h, left_c, right_h, right_c
        ).split(2)

    # Only the first element of the root pair feeds the 'logits' op.
    encoding, _ = encode_node(tree.root)
    return fold.add('logits', encoding)


# Placeholder kept from the original snippet: code defining ``args``,
# ``batch``, ``model`` and ``criterion`` is elided here.
...

fold = torchfold.Fold(cuda=args.cuda)
all_logits, all_labels = [], []
for tree in batch:
    # Fixed: the original called the undefined name ``encode_tree_folded``,
    # which would raise NameError; the function above is ``encode_tree_fold``.
    all_logits.append(encode_tree_fold(fold, tree))
    all_labels.append(tree.label)

# ``res`` is indexed in the same order as the list passed to ``apply``:
# res[0] is used as the prediction, res[1] as the target for ``criterion``.
# NOTE(review): presumably ``apply`` runs the queued ops on ``model`` in
# batches -- confirm against the torchfold implementation.
res = fold.apply(model, [all_logits, all_labels])
loss = criterion(res[0], res[1])