Last active
July 7, 2022 17:20
-
-
Save davidefiocco/1c9e437de7b31e81bf2b8fecbe1d63ed to your computer and use it in GitHub Desktop.
Example Prodigy recipe that uses a zero-shot classifier to pre-classify examples while labeling data for text classification (see https://support.prodi.gy/t/can-one-leverage-zero-shot-classifiers-for-textcat-tasks/4885).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{"text":"Spam spam lovely spam!"} | |
{"text":"I like scrambled eggs."} | |
{"text":"I prefer spam!"} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Usage:
# python -m prodigy textcat.zero-shot -F .\textcat_zero_shot.py my_dataset dataset.jsonl facebook/bart-large-mnli --label SPAM,EGGS
from typing import Iterable, List | |
import prodigy | |
from prodigy.components.loaders import JSONL | |
from prodigy.components.sorters import prefer_high_scores | |
from prodigy.util import split_string | |
from tqdm import tqdm | |
from transformers import pipeline | |
class ZeroShotClassifier(object):
    """Pre-label text examples with a zero-shot classification pipeline.

    Calling an instance on a stream of example dicts yields
    ``(score, example)`` tuples, where each example has been given a
    ``"label"`` and a formatted ``"score"`` entry in its ``"meta"`` dict.
    """

    def __init__(self, labels: List[str], model: str):
        """Build the pipeline and remember the candidate labels.

        Args:
            labels: Candidate labels passed to the pipeline for every example.
            model: Model name handed to ``pipeline(..., model=model)``.

        Raises:
            ValueError: If ``labels`` is empty or ``None``.
        """
        # Fail before the pipeline is built: without candidate labels every
        # example would fail later, only once the stream is consumed.
        if not labels:
            raise ValueError(
                "At least one candidate label is required for zero-shot "
                "classification (e.g. --label SPAM,EGGS)."
            )
        self.pipeline = pipeline("zero-shot-classification", model=model)
        self.labels = labels

    def __call__(self, stream: Iterable[dict]):
        """Classify each example in ``stream`` and yield ``(score, example)``.

        Each example must have a ``"text"`` key. The example dict is updated
        in place with the first label reported by the pipeline and that
        label's score.
        """
        for eg in tqdm(stream):
            result = self.pipeline(eg["text"], self.labels)
            # Index 0 is used as the prediction; the Hugging Face zero-shot
            # pipeline documents its labels/scores as sorted by descending score.
            eg["label"] = result["labels"][0]
            score = result["scores"][0]
            # Keep any metadata already attached to the example instead of
            # replacing it; anything that is not a dict is discarded as before.
            existing_meta = eg.get("meta")
            meta = dict(existing_meta) if isinstance(existing_meta, dict) else {}
            # format score to have it visualized in the UI
            meta["score"] = f"{score:.3f}"
            eg["meta"] = meta
            yield (score, eg)
@prodigy.recipe(
    "textcat.zero-shot",
    dataset=("The dataset to use", "positional", None, str),
    source=("The source data as a JSONL file", "positional", None, str),
    model=("Model name (from the Huggingface hub)", "positional", None, str),
    label=("One or more comma-separated labels", "option", "l", split_string),
)
def textcat_zero_shot(dataset: str, source: str, label: List[str], model: str):
    # Recipe entry point: read examples from the JSONL source, score them
    # with the zero-shot classifier and hand the re-ordered stream to the
    # "classification" annotation interface.

    # One dictionary per example in the source file.
    examples = JSONL(source)

    # The classifier gets its own name so the `model` argument (the model
    # name string) is not rebound to a different kind of object.
    classifier = ZeroShotClassifier(labels=label, model=model)

    # The classifier yields (score, example) tuples for the sorter.
    scored_examples = classifier(examples)
    ordered_examples = prefer_high_scores(scored_examples)

    components = {
        "dataset": dataset,  # Name of dataset to save annotations
        "view_id": "classification",  # Annotation interface to use
        "stream": ordered_examples,  # Incoming stream of examples
    }
    return components
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment