Last active May 20, 2021 10:40
Gist: magesh-technovator/934df1ddd9a32b041d1815bd1faca17e
Load DeepLabV3 model
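from torchvision import models
from torchvision.models.segmentation.deeplabv3 import DeepLabHead

def createDeepLabv3(outputchannels=1):
    # Load DeepLabV3 with a ResNet-101 backbone and pretrained weights.
    # (On torchvision >= 0.13, `pretrained=True` is deprecated in favour of
    # `weights=models.segmentation.DeepLabV3_ResNet101_Weights.DEFAULT`.)
    model = models.segmentation.deeplabv3_resnet101(
        pretrained=True, progress=True)
    # Replace the classifier head with a new DeepLabHead that maps the
    # 2048-channel ResNet-101 features to `outputchannels` channels.
    # No activation (e.g. Tanh) is added after the last convolution.
    model.classifier = DeepLabHead(2048, outputchannels)
    # model.aux_classifier is left unchanged and still predicts 21 classes.
    # Set the model in training mode
    model.train()
    return model

model = createDeepLabv3(3)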