This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Copyright (c) 2019-present NAVER Corp. | |
MIT License | |
""" | |
# -*- coding: utf-8 -*- | |
import numpy as np | |
import cv2 | |
import math |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load image as torch tensor | |
x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w] | |
x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w] | |
if cuda: | |
x = x.cuda() | |
# forward pass | |
with torch.no_grad(): | |
y, feature = net(x) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from skimage import io | |
import cv2 | |
def loadImage(img_file): | |
img = io.imread(img_file) # RGB order | |
if img.shape[0] == 2: img = img[0] | |
if len(img.shape) == 2 : img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) | |
if img.shape[2] == 4: img = img[:,:,:3] | |
img = np.array(img) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
cuda = False # set to True for GPU | |
# load net | |
net = CRAFT() | |
# Load the weights from pre-trained model | |
if cuda: | |
net.load_state_dict(copyStateDict(torch.load(args.trained_model))) | |
else: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Copyright (c) 2019-present NAVER Corp. | |
MIT License | |
""" | |
# -*- coding: utf-8 -*- | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os, sys | |
import numpy as np | |
import cv2 | |
import time | |
from imutils.object_detection import non_max_suppression | |
def east_detect(image): | |
layerNames = [ | |
"feature_fusion/Conv_7/Sigmoid", | |
"feature_fusion/concat_3"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import matplotlib.pyplot as plt | |
import cv2 | |
import pandas as pd | |
import time | |
model_path = "<model path>" | |
image_path = "<image_path>" | |
# Load the trained model |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv | |
import copy | |
import time | |
from tqdm import tqdm | |
import torch | |
import numpy as np | |
import os | |
def train_model(model, criterion, dataloaders, optimizer, metrics, bpath, num_epochs=3): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torchvision import models | |
from torchvision.models.segmentation.deeplabv3 import DeepLabHead | |
def createDeepLabv3(outputchannels=1): | |
model = models.segmentation.deeplabv3_resnet101( | |
pretrained=True, progress=True) | |
# Added a Tanh activation after the last convolution layer | |
model.classifier = DeepLabHead(2048, outputchannels) | |
# Set the model in training mode | |
model.train() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data import Dataset, DataLoader | |
import glob | |
import os | |
import numpy as np | |
import cv2 | |
import torch | |
from torchvision import transforms, utils | |
from PIL import Image | |
NewerOlder