Last active
September 1, 2024 21:00
-
-
Save pszemraj/5f0e3fdb6cc530d6108cb64207bec999 to your computer and use it in GitHub Desktop.
Inference with NVIDIA's domain classifier
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import os | |
import fire | |
import torch | |
from datasets import load_dataset | |
from huggingface_hub import PyTorchModelHubMixin | |
from torch import nn | |
from transformers import AutoConfig, AutoModel, AutoTokenizer | |
# Configure root logging once at import time so every message emitted by this
# script carries a timestamp and a level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
def check_ampere_gpu():
    """
    Enable TF32 math in PyTorch if the current CUDA GPU is Ampere or newer.

    When CUDA is available and the current device reports a compute capability
    of 8.0 or higher, this sets both ``torch.backends.cuda.matmul.allow_tf32``
    and ``torch.backends.cudnn.allow_tf32`` to True. Nothing is changed on CPU
    or on older GPUs.

    The probe is best-effort: any error raised while querying the device is
    logged and swallowed so that inference can still proceed.

    Returns:
        bool: True if TF32 was enabled, False otherwise.
    """
    if not torch.cuda.is_available():
        logger.info("No GPU detected, running on CPU.")
        return False

    try:
        device = torch.cuda.current_device()
        major, minor = torch.cuda.get_device_capability(device)
        gpu_name = torch.cuda.get_device_name(device)
    except Exception as e:  # best-effort probe; never block inference on it
        logger.warning(f"Error occurred while checking GPU: {e}")
        return False

    # Compute capability 8.x is Ampere; higher majors are newer architectures.
    if major < 8:
        logger.info(
            f"{gpu_name} (compute capability {major}.{minor}) does not support "
            "NVIDIA Ampere or later."
        )
        return False

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    logger.info(
        f"{gpu_name} (compute capability {major}.{minor}) supports NVIDIA Ampere "
        "or later, enabled TF32 in PyTorch."
    )
    return True
class DomainModel(nn.Module, PyTorchModelHubMixin):
    """
    Pretrained encoder followed by a dropout + linear classification head.

    ``config`` is a mapping from which three keys are read:
    ``"base_model"`` (passed to ``AutoModel.from_pretrained``),
    ``"fc_dropout"`` (dropout probability for the head) and
    ``"id2label"`` (its length sets the number of output classes).
    """

    def __init__(self, config):
        super().__init__()
        encoder = AutoModel.from_pretrained(config["base_model"])
        self.model = encoder
        self.dropout = nn.Dropout(config["fc_dropout"])
        hidden_size = encoder.config.hidden_size
        self.fc = nn.Linear(hidden_size, len(config["id2label"]))

    def forward(self, input_ids, attention_mask):
        """
        Return softmax class probabilities taken from the first token position.

        The head (dropout, then linear) is applied to every position of the
        encoder's ``last_hidden_state``; only index 0 along dim 1 is kept, and
        softmax is taken over dim 1 of that slice. Presumably
        ``last_hidden_state`` is (batch, seq, hidden), making the result
        (batch, num_labels) — TODO confirm for the chosen base model.
        """
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
        hidden = encoded.last_hidden_state
        logits = self.fc(self.dropout(hidden))
        first_token_logits = logits[:, 0, :]
        return torch.softmax(first_token_logits, dim=1)
def get_workers():
    """
    Return the number of worker processes to use for dataset loading.

    Uses half of the CPUs reported by ``os.cpu_count()``, but never fewer than
    one. ``os.cpu_count()`` returns ``None`` when the count cannot be
    determined, and halving a single CPU would yield 0, which is not a usable
    process count.
    """
    cpu_count = os.cpu_count() or 1
    return max(1, cpu_count // 2)
def get_device_type(model):
    """
    Return the type prefix of ``model.device`` (e.g. ``"cuda"`` for ``"cuda:0"``).

    The device is stringified and everything from the first ``":"`` onward is
    dropped; a device string with no index (``"cpu"``) is returned unchanged.
    """
    device_type, _sep, _rest = str(model.device).partition(":")
    return device_type
def load_model(model_name="nvidia/domain-classifier", device=None):
    """
    Load config, tokenizer and classifier for ``model_name``, ready for inference.

    Args:
        model_name: Repo id (or local path) passed to every ``from_pretrained`` call.
        device: Target device. When ``None``, ``"cuda"`` is chosen if
            ``torch.cuda.is_available()``, otherwise ``"cpu"``.

    Returns:
        tuple: ``(config, tokenizer, model, device)``; the model has been moved
        to ``device`` and switched to eval mode.
    """
    if device is None:
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
    logger.info(f"Using device: {device}")

    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = DomainModel.from_pretrained(model_name)
    model = model.to(device)
    model.eval()
    return config, tokenizer, model, device
def classify_batch(batch, tokenizer, model, config, device, text_column):
    """
    Predict a domain label for every text in one ``Dataset.map`` batch.

    Used as a ``batched=True`` map callback: ``batch`` is a dict of columns,
    and a ``"domain_prediction"`` column (one label string per row) is added
    to it in place.

    Args:
        batch: Column dict for one batch; ``batch[text_column]`` holds the texts.
        tokenizer: Tokenizer for ``model``. Texts are padded to the longest one
            in the batch and truncated (``truncation=True``).
        model: Classifier called as ``model(input_ids, attention_mask)``; the
            predicted class is the argmax over dim 1 of its output. Its
            ``.model`` attribute's device selects the autocast device type.
        config: Model config whose ``id2label`` maps class indices to labels.
        device: Device the tokenized inputs are moved to.
        text_column: Name of the column containing the input texts.

    Returns:
        The same ``batch`` dict with ``"domain_prediction"`` added.
    """
    inputs = tokenizer(
        batch[text_column], return_tensors="pt", padding="longest", truncation=True
    ).to(device)
    with torch.no_grad(), torch.autocast(get_device_type(model.model)):
        probabilities = model(inputs["input_ids"], inputs["attention_mask"])
        # A single .tolist() copies all indices to the host at once; calling
        # .item() per element forces a separate device->host sync per row.
        predicted_classes = torch.argmax(probabilities, dim=1).tolist()
    batch["domain_prediction"] = [
        config.id2label[class_idx] for class_idx in predicted_classes
    ]
    return batch
def main(
    dataset_name: str,
    text_column: str = "text",
    model_name: str = "nvidia/domain-classifier",
    batch_size: int = 32,
    output_dir: str = "domain_classified_dataset",
):
    """
    Classify every text in a dataset by domain and save the result to disk.

    Args:
        dataset_name: Name passed to ``datasets.load_dataset``.
        text_column: Column holding the texts to classify.
        model_name: Classifier repo id passed to ``load_model``.
        batch_size: Rows per ``Dataset.map`` batch; each batch is tokenized
            and run through the model in a single forward pass.
        output_dir: Directory passed to ``save_to_disk``. Defaults to the
            previously hard-coded ``"domain_classified_dataset"``.

    Returns:
        The mapped dataset with an added ``"domain_prediction"`` column.
    """
    logger.info(f"Loading dataset: {dataset_name}")
    dataset = load_dataset(dataset_name, num_proc=get_workers())
    logger.info(f"Dataset loaded: {dataset}")

    check_ampere_gpu()

    logger.info(f"Loading model: {model_name}")
    config, tokenizer, model, device = load_model(model_name)

    logger.info("Starting inference")
    classified_dataset = dataset.map(
        lambda batch: classify_batch(
            batch, tokenizer, model, config, device, text_column
        ),
        batched=True,
        batch_size=batch_size,
        desc="Classifying texts",
    )
    logger.info("Inference complete")

    logger.info(f"Saving updated dataset to: {output_dir}")
    classified_dataset.save_to_disk(output_dir)
    logger.info("Processing complete!")
    return classified_dataset
if __name__ == "__main__":
    # python-fire turns main()'s parameters into command-line flags.
    fire.Fire(component=main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment