Created April 12, 2022 10:34
# Required packages: torch
# Desired packages: numpy, matplotlib
# Add missing imports:


def main():
    training_images = {x for x in range(1, 100001)}
    priority_images = {x * 100 for x in range(1, 11)}

    # Task:
    # 1. Dataloader needs to return a batch of size 10. Each element is a unique element from `training_images`.
    # 2. Elements from `priority_images` must be present in EVERY batch with the ratio 1:1, meaning that batch will have
    #    50% of images from `priority_images` and 50% from `training_images` containers.
    #    NOTE: elements of `priority_images` are part of the `training_images`.
    # 3. The code needs to be multiprocessing safe.
    # 4. Follow comments below to add extra functionality.

    batch_size = 10
    num_epochs = 10

    for epoch in range(num_epochs):
        pass


if __name__ == '__main__':
    main()
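
One possible way to approach tasks 1 to 3 is a batch sampler that builds each batch from two pools. The sketch below is not the gist author's solution; `PriorityBatchSampler`, `set_epoch`, and the `seed` parameter are hypothetical names introduced for illustration. It assumes that the non-priority half of each batch is drawn from `training_images` minus `priority_images` (so that no element repeats within a batch), and that priority elements may reappear across batches, since there are only 10 of them. It uses only the standard library so it can be run without torch.

```python
import random


class PriorityBatchSampler:
    """Hypothetical sampler: each batch is half priority IDs, half regular IDs."""

    def __init__(self, training_images, priority_images, batch_size=10, seed=0):
        if batch_size % 2 != 0:
            raise ValueError("batch_size must be even for a 1:1 ratio")
        self.half = batch_size // 2
        self.priority = sorted(priority_images)
        # Assumption: priority IDs are excluded from the regular pool,
        # so an ID can never appear twice in the same batch.
        self.regular = sorted(set(training_images) - set(priority_images))
        if len(self.priority) < self.half:
            raise ValueError("not enough priority images to fill half a batch")
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        # Changing the epoch changes the shuffle order while keeping runs reproducible.
        self.epoch = epoch

    def __len__(self):
        # One epoch consumes the regular pool in chunks of `half`; a remainder is dropped.
        return len(self.regular) // self.half

    def __iter__(self):
        rng = random.Random(self.seed + self.epoch)
        regular = self.regular[:]
        rng.shuffle(regular)
        for start in range(0, len(self) * self.half, self.half):
            # Sample priority IDs without replacement within the batch.
            batch = rng.sample(self.priority, self.half)
            batch += regular[start:start + self.half]
            rng.shuffle(batch)
            yield batch


training_images = {x for x in range(1, 100001)}
priority_images = {x * 100 for x in range(1, 11)}
sampler = PriorityBatchSampler(training_images, priority_images, batch_size=10)

for epoch in range(2):
    sampler.set_epoch(epoch)
    first_batch = next(iter(sampler))
    print(epoch, first_batch)
```

With torch installed, an object like this could be passed as `batch_sampler=` to `torch.utils.data.DataLoader`; the dataset's `__getitem__` would then receive the yielded IDs as keys. In that arrangement the sampler is iterated in the main process and worker processes only load the items they are handed, which is one way to keep batch composition consistent when `num_workers > 0`.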