Skip to content

Instantly share code, notes, and snippets.

@dingusagar
Last active July 10, 2022 09:04
Show Gist options
  • Save dingusagar/6a56c852f2cbb8294d5c203ca3a81551 to your computer and use it in GitHub Desktop.
Save dingusagar/6a56c852f2cbb8294d5c203ca3a81551 to your computer and use it in GitHub Desktop.
from transformers import ViTFeatureExtractor, ViTForImageClassification, Trainer, TrainingArguments
import torch
from datasets import load_dataset, load_metric
import numpy as np
# Root directory of the image dataset; the split globs below are derived from
# it so that changing this one constant relocates every path consistently.
DATASET_DIR = '/content/output'

# The "train" split is read from <DATASET_DIR>/train and the "test" split from
# <DATASET_DIR>/val.
dataset = load_dataset(
    name="avengers",
    path=DATASET_DIR,
    data_files={
        "train": f"{DATASET_DIR}/train/**",
        "test": f"{DATASET_DIR}/val/**",
    },
)

# Class names taken from the 'label' feature of the train split; a label's
# position in this list is used as its integer id when configuring the model.
labels = dataset['train'].features['label'].names
def transform(example_batch):
    """Convert a batch of raw examples into model inputs.

    Args:
        example_batch: Mapping with an ``'image'`` entry (a list of PIL
            images) and a ``'label'`` entry (the matching labels).

    Returns:
        The output of the module-level ``feature_extractor`` for the images
        (requested as PyTorch tensors), with the labels added under the
        ``'labels'`` key.
    """
    # Force every image to RGB before handing the list to the extractor.
    rgb_images = [image.convert("RGB") for image in example_batch['image']]
    inputs = feature_extractor(rgb_images, return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['label']
    return inputs
# Register `transform` on the dataset; per the `datasets` documentation a
# transform set this way is applied on the fly when examples are accessed.
prepared_ds = dataset.with_transform(transform=transform)
def collate_fn(batch):
    """Assemble a list of per-example dicts into one batch dict.

    Args:
        batch: Sequence of dicts, each with a ``'pixel_values'`` tensor and a
            ``'labels'`` value.

    Returns:
        A dict with ``'pixel_values'`` (the per-example tensors stacked along
        a new leading dimension) and ``'labels'`` (a tensor built from the
        per-example labels, in the same order).
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    label_tensor = torch.tensor([example['labels'] for example in batch])
    return {
        'pixel_values': pixel_values,
        'labels': label_tensor,
    }
# Accuracy metric object; compute_metrics() calls its .compute() method.
metric = load_metric(path="accuracy")
def compute_metrics(p):
    """Score a set of predictions with the module-level ``metric``.

    Args:
        p: Object exposing ``predictions`` (per-example scores, one column
            per class) and ``label_ids`` (the reference labels).

    Returns:
        Whatever ``metric.compute`` returns for the predicted class ids
        against the reference labels.
    """
    # The predicted class is the index of the highest score along axis 1.
    predicted_classes = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_classes, references=p.label_ids)
# Checkpoint identifier shared by the feature extractor and the model.
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

# (string id, class name) pairs in label order; both label mappings passed to
# the model are built from this single list.
label_id_pairs = [(str(index), class_name) for index, class_name in enumerate(labels)]

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label=dict(label_id_pairs),
    label2id={class_name: index for index, class_name in label_id_pairs},
)
# Training configuration. Evaluation and checkpointing both run every 100
# steps, so each saved checkpoint has a matching evaluation when the best
# model is reloaded at the end of training.
training_args = TrainingArguments(
    output_dir="./vit-base-avengers-v1",
    # Was the garbled token `per_device_train_batch_siz16`, which is a syntax
    # error (positional argument after a keyword argument).
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    fp16=True,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    # Keep the raw 'image' column so the on-the-fly transform can still read
    # it when batches are built.
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)
# Wire the model, configuration, data splits and helper callables into one
# Trainer. The feature extractor is passed through the `tokenizer` argument.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["test"],
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
)
# --- Training: run it, then save the model, the train metrics and the state.
train_results = trainer.train()
trainer.save_model()
train_metrics = train_results.metrics
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)
trainer.save_state()

# --- Evaluation on the "test" split: log the metrics first, then save them.
metrics = trainer.evaluate(eval_dataset=prepared_ds['test'])
for record_metrics in (trainer.log_metrics, trainer.save_metrics):
    record_metrics("eval", metrics)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment