This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def load_train_data(path: str, sort: bool) -> List[Example]: | |
sentences = list() | |
with open(path) as f: | |
first = False | |
for line in f: | |
if not first: | |
first = True | |
continue | |
text_a, text_b, label = line.rstrip().split("\t") | |
lab = len(text_a) + len(text_b) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ... | |
def pad_seq(seq: List[int], max_batch_len: int, pad_value: int) -> List[int]:
    """Right-pad ``seq`` with ``pad_value`` until it has ``max_batch_len`` items.

    Args:
        seq: Token ids to pad. It is not mutated.
        max_batch_len: Target length, typically the longest sequence in the batch.
        pad_value: Id appended to fill the remaining positions.

    Returns:
        A new list. If ``seq`` already has ``max_batch_len`` or more items,
        an unpadded copy is returned; no truncation is performed.
    """
    # In real training code, prefer torch.nn.utils.rnn.pad_sequence:
    # https://pytorch.org/docs/master/generated/torch.nn.utils.rnn.pad_sequence.html
    # Clamp at 0 so over-long sequences get no padding. This makes explicit
    # what a negative list-repeat count would do implicitly.
    missing = max(0, max_batch_len - len(seq))
    return seq + [pad_value] * missing
@dataclass | |
class SmartCollator(DataCollator): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# required by (\ SHELL COMMANDS \) | |
SHELL:=/bin/bash | |
VIRT_ENV_FOLDER = ~/.local/share/virtualenvs/xnli | |
SOURCE_VIRT_ENV = source $(VIRT_ENV_FOLDER)/bin/activate | |
.PHONY: train | |
train: | |
( \ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.pommedeterresautee.rxtest; | |
import android.content.Intent; | |
import android.os.Bundle; | |
import android.app.Activity; | |
import android.view.Menu; | |
import android.widget.TextView; | |
import rx.Observer; |