Skip to content

Instantly share code, notes, and snippets.

@pszemraj
Last active September 1, 2024 20:28
Show Gist options
  • Save pszemraj/465b4d7e40d6e8f4178a428cc8c733fe to your computer and use it in GitHub Desktop.
Save pszemraj/465b4d7e40d6e8f4178a428cc8c733fe to your computer and use it in GitHub Desktop.
import logging
import os
import fire
import torch
from datasets import load_dataset
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer
# Configure logging via basicConfig: INFO threshold, timestamp/level/message format.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger used by the functions in this script.
logger = logging.getLogger(__name__)
class QualityModel(nn.Module, PyTorchModelHubMixin):
    """Classifier made of a pretrained encoder, dropout and a linear head.

    ``config`` is indexed with three keys: ``"base_model"`` (handed to
    ``AutoModel.from_pretrained``), ``"fc_dropout"`` (dropout probability) and
    ``"id2label"`` (its length is the number of output classes).
    """

    def __init__(self, config):
        super().__init__()
        backbone = AutoModel.from_pretrained(config["base_model"])
        self.model = backbone
        self.dropout = nn.Dropout(config["fc_dropout"])
        self.fc = nn.Linear(backbone.config.hidden_size, len(config["id2label"]))

    def forward(self, input_ids, attention_mask):
        """Return softmax class probabilities taken at index 0 of dimension 1.

        NOTE(review): dimension 1 of ``last_hidden_state`` is presumably the
        token/sequence axis, making index 0 the first token - confirm.
        """
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
        hidden = encoded.last_hidden_state
        logits = self.fc(self.dropout(hidden))
        # Keep only index 0 along dim 1, then normalise over the class axis.
        selected = logits[:, 0, :]
        return torch.softmax(selected, dim=1)
def get_workers() -> int:
    """Return half of the reported CPU count, but never less than one.

    ``os.cpu_count()`` may return ``None`` when the count cannot be
    determined; that case is treated as a single CPU. The result is floored
    at 1 so a single-core host does not yield zero workers.
    """
    cpus = os.cpu_count() or 1
    return max(1, cpus // 2)
def get_device_type(model) -> str:
    """Return the device type (e.g. ``"cpu"`` or ``"cuda"``) a model lives on.

    Objects exposing a ``device`` attribute are read directly. A plain
    ``torch.nn.Module`` has no such attribute, so the device of its first
    parameter is used instead. A model without parameters is reported as
    ``"cpu"``.

    Any index suffix is dropped, so ``"cuda:0"`` becomes ``"cuda"``.
    """
    device = getattr(model, "device", None)
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = "cpu" if first_param is None else first_param.device
    return str(device).split(":")[0]
def load_model(model_name="nvidia/quality-classifier-deberta", device=None):
    """Load config, tokenizer and classifier for ``model_name``.

    If ``device`` is ``None`` it becomes ``"cuda"`` when
    ``torch.cuda.is_available()`` is true and ``"cpu"`` otherwise. The model
    is moved to that device and switched to eval mode.

    Returns the tuple ``(config, tokenizer, model, device)``.
    """
    if device is None:
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
    logger.info(f"Using device: {device}")

    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    classifier = QualityModel.from_pretrained(model_name)
    classifier = classifier.to(device)
    classifier.eval()
    return config, tokenizer, classifier, device
def classify_batch(batch, tokenizer, model, config, device, text_column):
    """Add a ``"quality_prediction"`` label list to ``batch`` and return it.

    The texts in ``batch[text_column]`` are tokenized, moved to ``device`` and
    passed through ``model``. For each row the highest-scoring class index is
    mapped to a label via ``config.id2label``.

    Args:
        batch: mapping that holds the texts under ``text_column``; it is
            modified in place.
        tokenizer: callable producing an encoding with ``input_ids`` and
            ``attention_mask`` that supports ``.to(device)``.
        model: callable taking ``(input_ids, attention_mask)`` and returning
            per-class scores with classes along dimension 1.
        config: object whose ``id2label`` maps integer class ids to labels.
        device: device string or ``torch.device`` the inputs are moved to.
        text_column: key of the text field in ``batch``.
    """
    inputs = tokenizer(
        batch[text_column], return_tensors="pt", padding="longest", truncation=True
    ).to(device)
    # Take the autocast device type from ``device`` itself: a plain nn.Module
    # has no ``.device`` attribute to query.
    device_type = torch.device(device).type
    with torch.no_grad(), torch.autocast(device_type):
        scores = model(inputs["input_ids"], inputs["attention_mask"])
    # One ``tolist()`` transfer instead of a per-element ``.item()`` call.
    class_ids = torch.argmax(scores, dim=1).tolist()
    batch["quality_prediction"] = [config.id2label[idx] for idx in class_ids]
    return batch
def main(
    dataset_name: str,
    text_column: str = "text",
    model_name: str = "nvidia/quality-classifier-deberta",
    batch_size: int = 32,
    output_dir: str = "quality_classified_dataset",
):
    """Classify every text in a dataset and save the labelled result to disk.

    Args:
        dataset_name: identifier passed to ``load_dataset``.
        text_column: name of the column holding the text to classify.
        model_name: identifier passed to ``load_model``.
        batch_size: number of rows handed to ``classify_batch`` per call.
        output_dir: path passed to ``save_to_disk``; defaults to the
            previously hard-coded ``"quality_classified_dataset"``.

    Returns:
        The dataset returned by ``dataset.map``, which applies
        ``classify_batch`` to each batch.
    """
    logger.info(f"Loading dataset: {dataset_name}")
    dataset = load_dataset(dataset_name, num_proc=get_workers())
    logger.info(f"Dataset loaded: {dataset}")

    logger.info(f"Loading model: {model_name}")
    config, tokenizer, model, device = load_model(model_name)

    logger.info("Starting inference")
    classified_dataset = dataset.map(
        lambda batch: classify_batch(
            batch, tokenizer, model, config, device, text_column
        ),
        batched=True,
        batch_size=batch_size,
        desc="Classifying texts",
    )
    logger.info("Inference complete")

    logger.info("Saving updated dataset")
    classified_dataset.save_to_disk(output_dir)
    logger.info("Processing complete!")
    return classified_dataset
# Script entry point: hand ``main`` to ``fire.Fire`` when run directly.
if __name__ == "__main__":
    fire.Fire(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment