This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torchvision | |
import pygame | |
import numpy as np | |
import gymnasium as gym | |
class MNISTSokoban(gym.Env): | |
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 20} | |
def __init__(self, map_file: str = None, size: tuple[int, int] = None, max_crates: int = 5, max_steps=200, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"editor.minimap.enabled": false, | |
"editor.fontFamily": "SF Mono", | |
"editor.bracketPairColorization.enabled": false, | |
"editor.acceptSuggestionOnEnter": "off", | |
"editor.inlineSuggest.enabled": true, | |
"files.autoSave": "onFocusChange", | |
"files.associations": { | |
"*.pro": "prolog", | |
"*.pl": "perl" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import matplotlib.pyplot as plt | |
def sample_gumbel_diff(*shape): | |
eps = 1e-20 | |
u1 = torch.rand(shape) | |
u2 = torch.rand(shape) | |
diff = torch.log(torch.log(u2+eps)/torch.log(u1+eps)+eps) | |
return diff |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from PIL import Image | |
def save_with_color(x, filename, n_per_row, pad_value=[0., 0., 0.], padding=2): | |
N, C, H, W = x.shape | |
n_per_col = N // n_per_row | |
if N > n_per_row * n_per_col: | |
n_per_col += 1 | |
canvas = torch.empty(H*n_per_col+(n_per_col-1)*padding, | |
W*n_per_row+(n_per_row-1)*padding, | |
3, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import numpy as np | |
def p2dist_pytorch(x, y): | |
y_dim = len(y.shape) | |
return torch.pow(x, 2).sum(dim=-1).view(x.shape[:-1]+(1,)) - \ | |
2 * torch.matmul(x, y.permute(list(range(y_dim-2))+[y_dim-1, y_dim-2])) + \ | |
torch.pow(y, 2).sum(dim=-1).view(y.shape[:-2]+(1,y.shape[-2])) | |
def p2dist_numpy(x, y): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class SoftNode(torch.nn.Module): | |
def __init__(self, in_features, out_features, depth, projection="constant"): | |
super(SoftNode, self).__init__() | |
self.projection = projection | |
if depth > 0: | |
self.left = SoftNode(in_features, out_features, depth-1, projection=projection) | |
self.right = SoftNode(in_features, out_features, depth-1, projection=projection) | |
self.gating = torch.nn.Linear(in_features, 1) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torchvision | |
class VGGPerceptualLoss(torch.nn.Module): | |
def __init__(self, resize=True): | |
super(VGGPerceptualLoss, self).__init__() | |
blocks = [] | |
blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval()) | |
blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval()) | |
blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval()) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# MIT License | |
# | |
# Copyright (c) 2024 Alper Ahmetoglu | |
# | |
# Permission is hereby granted, free of charge, to any person obtaining a copy | |
# of this software and associated documentation files (the "Software"), to deal | |
# in the Software without restriction, including without limitation the rights | |
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
# copies of the Software, and to permit persons to whom the Software is | |
# furnished to do so, subject to the following conditions: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class SoftTree(torch.nn.Module): | |
def __init__(self, in_features, out_features, depth, projection='constant', dropout=0.0): | |
super(SoftTree, self).__init__() | |
self.proj = projection | |
self.depth = depth | |
self.in_features = in_features | |
self.out_features = out_features |