Skip to content

Instantly share code, notes, and snippets.

@airicbear
Created February 19, 2023 18:35
Show Gist options
  • Save airicbear/0a923f12ae4d1187095616f294550ad3 to your computer and use it in GitHub Desktop.
Save airicbear/0a923f12ae4d1187095616f294550ad3 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
import tensorflow as tf

USING_KAGGLE = False
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Connect to a TPU cluster when one is available; otherwise fall back to the
# default (CPU/GPU) distribution strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except ValueError:
    # No TPU found.  The original bare `except:` also swallowed
    # KeyboardInterrupt/SystemExit; TPUClusterResolver raises ValueError
    # when no TPU is configured.
    strategy = tf.distribute.get_strategy()


def decode_image(image: tf.Tensor, image_size: 'tuple[int, int]' = (256, 256)) -> tf.Tensor:
    """Decode a JPEG byte string into a float32 image scaled to [-1, 1].

    Args:
        image: Scalar string tensor holding the raw JPEG bytes.
        image_size: Expected (height, width) of the decoded image.  A tuple
            default replaces the original mutable-list default `[256, 256]`.

    Returns:
        A float32 tensor of shape (height, width, 3) with values in [-1, 1].
    """
    image = tf.image.decode_jpeg(image, channels=3)
    image = (tf.cast(image, tf.float32) / 127.5) - 1  # [0, 255] -> [-1, 1]
    image = tf.reshape(image, [*image_size, 3])
    return image


def read_tfrecord(example: tf.Tensor) -> tf.Tensor:
    """Parse one serialized TFRecord example and return its decoded image."""
    tfrecord_format = {
        'image': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(example, tfrecord_format)
    return decode_image(example['image'])


def read_tfrecords(filenames: 'list[str]') -> tf.data.TFRecordDataset:
    """Build a dataset of decoded images from the given TFRecord shard files."""
    dataset = tf.data.TFRecordDataset(filenames)
    return dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
from typing import Callable

import tensorflow as tf
import numpy as np
import re

# Shard filenames look like "monet-0123.tfrec": the number before the
# extension is the count of records in that shard.  Compiled once instead of
# per filename inside the comprehension.
_SHARD_COUNT_RE = re.compile(r"-([0-9]*)\.")


def dataset_path(path: str = '../Sample_Data/Raw') -> str:
    """Return the root directory (or GCS path on Kaggle) of the TFRecord data.

    The original docstring claimed this loads the images; it only resolves the
    dataset location.
    """
    if USING_KAGGLE:
        from kaggle_datasets import KaggleDatasets
        return KaggleDatasets().get_gcs_path()
    return path


def monet_filenames() -> 'list[str]':
    """Return the TFRecord shard filenames for the Monet paintings."""
    return tf.io.gfile.glob(f'{dataset_path()}/monet_tfrec/*.tfrec')


def photo_filenames() -> 'list[str]':
    """Return the TFRecord shard filenames for the photos."""
    return tf.io.gfile.glob(f'{dataset_path()}/photo_tfrec/*.tfrec')


def monet_dataset() -> tf.data.TFRecordDataset:
    """Load the Monet paintings as a `TFRecordDataset` of decoded images."""
    return read_tfrecords(monet_filenames())


def photo_dataset() -> tf.data.TFRecordDataset:
    """Load the photos as a `TFRecordDataset` of decoded images."""
    return read_tfrecords(photo_filenames())


def count_data_items(filenames: 'list[str]') -> int:
    """Sum the per-shard record counts encoded in the TFRecord filenames.

    Returns 0 for an empty filename list.  The original `-> np.array`
    annotation was wrong: `np.sum` of the counts yields a scalar.

    Raises:
        AttributeError: if a filename lacks a "-<count>." segment.
    """
    counts = [int(_SHARD_COUNT_RE.search(filename).group(1)) for filename in filenames]
    return int(np.sum(counts))


def count_monet_samples() -> int:
    """Total number of Monet painting records across all shards."""
    return count_data_items(monet_filenames())


def count_photo_samples() -> int:
    """Total number of photo records across all shards."""
    return count_data_items(photo_filenames())
import tensorflow as tf
from tensorflow_addons.layers import InstanceNormalization
from typing import Callable, Optional


def get_gan_dataset(augment: 'Optional[Callable[[tf.Tensor], tf.Tensor]]' = None,
                    repeat: bool = True,
                    shuffle: bool = True,
                    batch_size: int = 1) -> tf.data.Dataset:
    """Zip the Monet and photo datasets into (monet_batch, photo_batch) pairs.

    Args:
        augment: Optional batch-level augmentation mapped over both datasets.
            (The original annotated this plain `Callable` with a None default.)
        repeat: Repeat both datasets indefinitely, as required when training
            with a fixed `steps_per_epoch`.
        shuffle: Shuffle each dataset with a 2048-element buffer.
        batch_size: Batch size; incomplete final batches are dropped so all
            batches have a static shape.
    """
    monets = monet_dataset()
    photos = photo_dataset()

    if repeat:
        monets = monets.repeat()
        photos = photos.repeat()

    if shuffle:
        monets = monets.shuffle(2048)
        photos = photos.shuffle(2048)

    monets = monets.batch(batch_size, drop_remainder=True)
    photos = photos.batch(batch_size, drop_remainder=True)

    # Augmentation runs after batching, so `augment` receives whole batches.
    if augment:
        monets = monets.map(augment, num_parallel_calls=AUTOTUNE)
        photos = photos.map(augment, num_parallel_calls=AUTOTUNE)

    monets = monets.prefetch(AUTOTUNE)
    photos = photos.prefetch(AUTOTUNE)

    return tf.data.Dataset.zip((monets, photos))


def downsample(filters: int, size: int, apply_instancenorm: bool = True) -> tf.keras.Sequential:
    """Stride-2 Conv2D block (halves spatial size), optional InstanceNorm, ReLU.

    Args:
        filters: Number of convolution filters.
        size: Square kernel size.
        apply_instancenorm: Skip normalization for the first encoder layer.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                                      kernel_initializer=initializer, use_bias=False))
    if apply_instancenorm:
        result.add(InstanceNormalization(gamma_initializer=gamma_init))
    result.add(tf.keras.layers.ReLU())

    return result
import tensorflow as tf
from tensorflow_addons.layers import InstanceNormalization


def upsample(filters: int, size: int, apply_dropout: bool = False) -> tf.keras.Sequential:
    """Stride-2 transposed-conv block (doubles spatial size), InstanceNorm,
    optional 0.5 dropout, then ReLU.

    Args:
        filters: Number of transposed-convolution filters.
        size: Square kernel size.
        apply_dropout: Enable dropout (used on the deepest decoder stages).
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    # Assemble the layer list first, then wrap it in a Sequential.
    layers = [
        tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                                        kernel_initializer=kernel_init, use_bias=False),
        InstanceNormalization(gamma_initializer=gamma_init),
    ]
    if apply_dropout:
        layers.append(tf.keras.layers.Dropout(0.5))
    layers.append(tf.keras.layers.ReLU())

    block = tf.keras.Sequential()
    for layer in layers:
        block.add(layer)
    return block
import tensorflow as tf
from typing import List


def Generator(output_channels: int = 3) -> tf.keras.Model:
    """Build a U-Net style generator for 256x256 RGB images.

    The encoder (`down_stack`) halves the spatial resolution at each step down
    to 1x1; the decoder (`up_stack`) mirrors it with transposed convolutions,
    concatenating the matching encoder activation (skip connection) at every
    resolution.  A final tanh-activated transposed convolution maps back to
    `output_channels` channels with values in [-1, 1].
    (The original docstring was truncated mid-sentence; rewritten.)
    """
    down_stack = [
        downsample(64, 4, apply_instancenorm=False),  # 256 -> 128
        downsample(128, 4),  # -> 64
        downsample(256, 4),  # -> 32
        downsample(512, 4),  # -> 16
        downsample(512, 4),  # -> 8
        downsample(512, 4),  # -> 4
        downsample(512, 4),  # -> 2
        downsample(512, 4),  # -> 1 (bottleneck)
    ]

    up_stack = [
        upsample(512, 4, apply_dropout=True),  # 1 -> 2
        upsample(512, 4, apply_dropout=True),  # -> 4
        upsample(512, 4, apply_dropout=True),  # -> 8
        upsample(512, 4),  # -> 16
        upsample(256, 4),  # -> 32
        upsample(128, 4),  # -> 64
        upsample(64, 4),   # -> 128
    ]

    initializer = tf.random_normal_initializer(0., 0.02)

    inputs = tf.keras.layers.Input(shape=[256, 256, 3])
    x = inputs

    # Encoder: record every activation for the skip connections.
    skips: List[tf.Tensor] = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # Pair decoder stages with encoder activations deepest-first; the
    # bottleneck activation itself is excluded.
    skips = list(reversed(skips[:-1]))

    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    # Final upsample back to 256x256 with tanh output in [-1, 1].
    x = tf.keras.layers.Conv2DTranspose(output_channels,
                                        4,
                                        strides=2,
                                        padding='same',
                                        kernel_initializer=initializer,
                                        activation='tanh')(x)

    return tf.keras.Model(inputs=inputs, outputs=x)
import tensorflow as tf
from tensorflow_addons.layers import InstanceNormalization
from typing import Callable, Dict, List


def Discriminator() -> tf.keras.Model:
    """PatchGAN-style discriminator: 256x256x3 image -> grid of real/fake logits."""
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    inputs = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')

    x = inputs
    down1 = downsample(64, 4, False)(x)   # 256 -> 128, no InstanceNorm
    down2 = downsample(128, 4)(down1)     # -> 64
    down3 = downsample(256, 4)(down2)     # -> 32
    zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)
    conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                  kernel_initializer=initializer,
                                  use_bias=False)(zero_pad1)
    norm1 = InstanceNormalization(gamma_initializer=gamma_init)(conv)
    leaky_relu = tf.keras.layers.LeakyReLU()(norm1)
    zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)
    # Single-channel logit map (no activation: losses use from_logits=True).
    outputs = tf.keras.layers.Conv2D(1, 4, strides=1,
                                     kernel_initializer=initializer)(zero_pad2)

    return tf.keras.Model(inputs=inputs, outputs=outputs)


# DiffAugment-style differentiable augmentations (Zhao et al. 2020).
# NOTE(review): the original defined these inside `strategy.scope()`; plain
# function definitions create no variables, so the scope wrapper has no effect
# and is dropped here.

def diff_augment(x: tf.Tensor,
                 policy: str = '',
                 channels_first: bool = False) -> tf.Tensor:
    """Apply a comma-separated augmentation policy to a batch of images.

    Operates on NHWC batches; set `channels_first` for NCHW input.
    """
    if policy:
        if channels_first:
            x = tf.transpose(x, [0, 2, 3, 1])  # NCHW -> NHWC
        for name in policy.split(','):
            for fn in AUGMENT_FNS[name]:
                x = fn(x)
        if channels_first:
            x = tf.transpose(x, [0, 3, 1, 2])  # back to NCHW
    return x


def rand_brightness(x: tf.Tensor) -> tf.Tensor:
    """Add a per-image random brightness offset drawn from [-0.5, 0.5)."""
    offset = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) - 0.5
    return x + offset


def rand_saturation(x: tf.Tensor) -> tf.Tensor:
    """Scale each pixel's deviation from a (scaled) channel mean by a random factor in [0, 2)."""
    factor = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) * 2
    # NOTE(review): the reference DiffAugment uses the plain channel mean here;
    # the extra 1/3 factor deviates from it -- confirm it is intentional.
    x_mean = tf.reduce_mean(x, axis=3, keepdims=True) * 0.3333333333333333333
    return (x - x_mean) * factor + x_mean


def rand_contrast(x: tf.Tensor) -> tf.Tensor:
    """Scale each image's deviation from a (scaled) global mean by a random factor in [0.5, 1.5)."""
    factor = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) + 0.5
    # NOTE(review): 5.086e-6 ~= 1/(256*256*3); combined with reduce_mean this
    # drives x_mean toward 0, unlike the reference DiffAugment which uses the
    # mean alone -- confirm it is intentional.
    x_mean = tf.reduce_mean(x, axis=[1, 2, 3], keepdims=True) * 5.086e-6
    return (x - x_mean) * factor + x_mean
def rand_translation(x: tf.Tensor,
                     ratio: float = 0.125) -> tf.Tensor:
    """Randomly shift each image by up to `ratio` of its height/width.

    Pads the batch by one pixel on each side, builds per-image shifted index
    grids, then gathers rows and (after a transpose) columns, so positions
    pushed past the border clamp to the edge pixels.
    """
    batch = tf.shape(x)[0]
    hw = tf.shape(x)[1:3]
    max_shift = tf.cast(tf.cast(hw, tf.float32) * ratio + 0.5, tf.int32)
    shift_x = tf.random.uniform([batch, 1], -max_shift[0], max_shift[0] + 1, dtype=tf.int32)
    shift_y = tf.random.uniform([batch, 1], -max_shift[1], max_shift[1] + 1, dtype=tf.int32)
    # The +1 accounts for the one-pixel padding added below; clipping clamps
    # out-of-range indices onto the padded border.
    grid_x = tf.clip_by_value(tf.expand_dims(tf.range(hw[0], dtype=tf.int32), 0) + shift_x + 1,
                              0,
                              hw[0] + 1)
    grid_y = tf.clip_by_value(tf.expand_dims(tf.range(hw[1], dtype=tf.int32), 0) + shift_y + 1,
                              0,
                              hw[1] + 1)
    # Shift rows, then transpose and shift columns the same way.
    x = tf.gather_nd(tf.pad(x, [[0, 0], [1, 1], [0, 0], [0, 0]]),
                     tf.expand_dims(grid_x, -1),
                     batch_dims=1)
    x = tf.transpose(tf.gather_nd(tf.pad(tf.transpose(x, [0, 2, 1, 3]), [[0, 0], [1, 1], [0, 0], [0, 0]]),
                                  tf.expand_dims(grid_y, -1),
                                  batch_dims=1),
                     [0, 2, 1, 3])
    return x
from typing import Callable, Dict, List


def rand_cutout(x: tf.Tensor,
                ratio: float = 0.5) -> tf.Tensor:
    """Zero out one random square patch (side = `ratio` * image size) per image."""
    batch = tf.shape(x)[0]
    hw = tf.shape(x)[1:3]
    cut = tf.cast(tf.cast(hw, tf.float32) * ratio + 0.5, tf.int32)
    # Random patch centers; the parity term keeps the patch inside the image
    # grid for odd sizes.
    off_x = tf.random.uniform([tf.shape(x)[0], 1, 1],
                              maxval=hw[0] + (1 - cut[0] % 2),
                              dtype=tf.int32)
    off_y = tf.random.uniform([tf.shape(x)[0], 1, 1],
                              maxval=hw[1] + (1 - cut[1] % 2),
                              dtype=tf.int32)
    grid_b, grid_x, grid_y = tf.meshgrid(tf.range(batch, dtype=tf.int32),
                                         tf.range(cut[0], dtype=tf.int32),
                                         tf.range(cut[1], dtype=tf.int32),
                                         indexing='ij')
    cut_grid = tf.stack([grid_b,
                         grid_x + off_x - cut[0] // 2,
                         grid_y + off_y - cut[1] // 2],
                        axis=-1)
    mask_shape = tf.stack([batch, hw[0], hw[1]])
    # Clamp patch coordinates into the image, scatter ones where the patch
    # lands, and invert to obtain the keep-mask.
    cut_grid = tf.maximum(cut_grid, 0)
    cut_grid = tf.minimum(cut_grid, tf.reshape(mask_shape - 1, [1, 1, 1, 3]))
    mask = tf.maximum(1 - tf.scatter_nd(cut_grid,
                                        tf.ones([batch, cut[0], cut[1]], dtype=tf.float32),
                                        mask_shape),
                      0)
    return x * tf.expand_dims(mask, axis=3)


# Registry mapping policy names to their augmentation pipelines.
AUGMENT_FNS: Dict[str, List[Callable[[tf.Tensor], tf.Tensor]]] = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}


def aug_fn(image: tf.Tensor) -> tf.Tensor:
    """Full DiffAugment policy applied to discriminator inputs."""
    return diff_augment(image,
                        "color,translation,cutout")


def data_augment_flip(image: tf.Tensor) -> tf.Tensor:
    """Dataset-level augmentation: random horizontal flip."""
    return tf.image.random_flip_left_right(image)
import tensorflow as tf
from typing import Callable


class CycleGan(tf.keras.Model):
    """CycleGAN with paired generators and discriminators for Monet <-> photo.

    Training combines adversarial, cycle-consistency, and identity losses;
    the Monet discriminator is fed DiffAugment-ed images via `aug_fn`.
    """

    def __init__(self,
                 monet_generator: tf.keras.Model,
                 photo_generator: tf.keras.Model,
                 monet_discriminator: tf.keras.Model,
                 photo_discriminator: tf.keras.Model,
                 lambda_cycle: int = 3,
                 lambda_id: int = 3,
                 ):
        """Store the four sub-models and the cycle/identity loss weights."""
        super().__init__()
        self.monet_generator = monet_generator
        self.photo_generator = photo_generator
        self.monet_discriminator = monet_discriminator
        self.photo_discriminator = photo_discriminator
        self.lambda_cycle = lambda_cycle
        self.lambda_id = lambda_id

    def compile(self,
                monet_generator_optimizer: tf.keras.optimizers.Optimizer,
                photo_generator_optimizer: tf.keras.optimizers.Optimizer,
                monet_discriminator_optimizer: tf.keras.optimizers.Optimizer,
                photo_discriminator_optimizer: tf.keras.optimizers.Optimizer,
                generator_loss_fn: Callable[[tf.keras.Model], tf.Tensor],
                discriminator_loss_fn: Callable[[tf.keras.Model, tf.keras.Model], tf.Tensor],
                cycle_loss_fn: Callable[[tf.Tensor, tf.Tensor, float], tf.Tensor],
                identity_loss_fn: Callable[[tf.Tensor, tf.Tensor, float], tf.Tensor],
                aug_fn):
        """Attach the four optimizers, the loss functions, and the augmentation."""
        super().compile()
        self.monet_generator_optimizer = monet_generator_optimizer
        self.photo_generator_optimizer = photo_generator_optimizer
        self.monet_discriminator_optimizer = monet_discriminator_optimizer
        self.photo_discriminator_optimizer = photo_discriminator_optimizer
        self.generator_loss_fn = generator_loss_fn
        self.discriminator_loss_fn = discriminator_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn
        self.aug_fn = aug_fn

    @tf.function
    def train_step(self, batch_data: 'tuple[tf.Tensor, tf.Tensor]'):
        """Run one optimization step on a (monet_batch, photo_batch) pair."""
        real_monet, real_photo = batch_data
        batch_size = tf.shape(real_monet)[0]

        with tf.GradientTape(persistent=True) as tape:
            # Forward cycles: photo -> monet -> photo, monet -> photo -> monet.
            fake_monet = self.monet_generator(real_photo, training=True)
            cycled_photo = self.photo_generator(fake_monet, training=True)

            fake_photo = self.photo_generator(real_monet, training=True)
            cycled_monet = self.monet_generator(fake_photo, training=True)

            # Identity mappings: each generator applied to its own domain.
            same_monet = self.monet_generator(real_monet, training=True)
            same_photo = self.photo_generator(real_photo, training=True)

            # Augment real and fake Monets in a single batch so both receive
            # identical transformation statistics, then split them back apart.
            both_monet = tf.concat([real_monet, fake_monet], axis=0)
            aug_monet = self.aug_fn(both_monet)
            aug_real_monet = aug_monet[:batch_size]
            aug_fake_monet = aug_monet[batch_size:]

            discriminator_real_monet = self.monet_discriminator(aug_real_monet, training=True)
            discriminator_real_photo = self.photo_discriminator(real_photo, training=True)

            discriminator_fake_monet = self.monet_discriminator(aug_fake_monet, training=True)
            discriminator_fake_photo = self.photo_discriminator(fake_photo, training=True)

            monet_generator_loss = self.generator_loss_fn(discriminator_fake_monet)
            photo_generator_loss = self.generator_loss_fn(discriminator_fake_photo)

            # Cycle/identity weights are scaled down by the batch size
            # (computed once instead of four times as in the original).
            cycle_weight = self.lambda_cycle / tf.cast(batch_size, tf.float32)
            id_weight = self.lambda_id / tf.cast(batch_size, tf.float32)

            total_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, cycle_weight) \
                + self.cycle_loss_fn(real_photo, cycled_photo, cycle_weight)

            total_monet_generator_loss = monet_generator_loss + total_cycle_loss \
                + self.identity_loss_fn(real_monet, same_monet, id_weight)
            total_photo_generator_loss = photo_generator_loss + total_cycle_loss \
                + self.identity_loss_fn(real_photo, same_photo, id_weight)

            monet_discriminator_loss = self.discriminator_loss_fn(discriminator_real_monet, discriminator_fake_monet)
            photo_discriminator_loss = self.discriminator_loss_fn(discriminator_real_photo, discriminator_fake_photo)

        # The persistent tape is reused for all four gradient computations.
        monet_generator_gradients = tape.gradient(total_monet_generator_loss, self.monet_generator.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_generator_loss, self.photo_generator.trainable_variables)

        monet_discriminator_gradients = tape.gradient(monet_discriminator_loss, self.monet_discriminator.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_discriminator_loss, self.photo_discriminator.trainable_variables)

        self.monet_generator_optimizer.apply_gradients(zip(monet_generator_gradients, self.monet_generator.trainable_variables))
        self.photo_generator_optimizer.apply_gradients(zip(photo_generator_gradients, self.photo_generator.trainable_variables))

        self.monet_discriminator_optimizer.apply_gradients(zip(monet_discriminator_gradients, self.monet_discriminator.trainable_variables))
        self.photo_discriminator_optimizer.apply_gradients(zip(photo_discriminator_gradients, self.photo_discriminator.trainable_variables))

        return {
            'monet_generator_loss': total_monet_generator_loss,
            'photo_generator_loss': total_photo_generator_loss,
            'monet_discriminator_loss': monet_discriminator_loss,
            'photo_discriminator_loss': photo_discriminator_loss
        }
import tensorflow as tf

# One shared loss object: BinaryCrossentropy holds no trainable state, so it
# is safe (and cheaper) to construct once instead of on every call, as the
# original did.  NONE reduction keeps the per-patch losses unreduced.
_bce = tf.keras.losses.BinaryCrossentropy(from_logits=True,
                                          reduction=tf.keras.losses.Reduction.NONE)


def discriminator_loss(real: tf.Tensor, generated: tf.Tensor) -> tf.Tensor:
    """Mean of BCE(real vs ones) and BCE(generated vs zeros) on logit maps.

    The original annotations said `tf.keras.Model`; both arguments are
    discriminator output logit tensors.
    """
    real_loss = _bce(tf.ones_like(real), real)
    generated_loss = _bce(tf.zeros_like(generated), generated)
    total_discriminator_loss = real_loss + generated_loss
    return total_discriminator_loss * 0.5


def generator_loss(generated: tf.Tensor) -> tf.Tensor:
    """BCE of the discriminator's fake logits against an all-ones target."""
    return _bce(tf.ones_like(generated), generated)


def calc_cycle_loss(real_image: tf.Tensor, cycled_image: tf.Tensor, alpha: float) -> tf.Tensor:
    """Weighted L1 distance between an image and its round-trip reconstruction."""
    loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return alpha * loss1


def identity_loss(real_image: tf.Tensor, same_image: tf.Tensor, alpha: float) -> tf.Tensor:
    """Half-weighted L1 distance between an image and its same-domain mapping."""
    loss = tf.reduce_mean(tf.abs(real_image - same_image))
    return alpha * 0.5 * loss
import tensorflow as tf

# Build the four sub-models and the combined CycleGAN.
# NOTE(review): for TPU training these should probably be constructed inside
# `strategy.scope()` so their variables land on the replicas -- confirm.
monet_generator = Generator()
photo_generator = Generator()
monet_discriminator = Discriminator()
photo_discriminator = Discriminator()

cycle_gan_model = CycleGan(monet_generator, photo_generator, monet_discriminator, photo_discriminator)


def cycle_gan_compile(monet_generator_optimizer: tf.keras.optimizers.Optimizer,
                      photo_generator_optimizer: tf.keras.optimizers.Optimizer,
                      monet_discriminator_optimizer: tf.keras.optimizers.Optimizer,
                      photo_discriminator_optimizer: tf.keras.optimizers.Optimizer) -> None:
    """Compile the global model with the given optimizers and this module's
    loss and augmentation functions."""
    cycle_gan_model.compile(
        monet_generator_optimizer=monet_generator_optimizer,
        photo_generator_optimizer=photo_generator_optimizer,
        monet_discriminator_optimizer=monet_discriminator_optimizer,
        photo_discriminator_optimizer=photo_discriminator_optimizer,
        generator_loss_fn=generator_loss,
        discriminator_loss_fn=discriminator_loss,
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
        aug_fn=aug_fn,
    )


def cycle_gan_compile_with_loss_rate(loss_rate: float) -> None:
    """Compile with four Adam optimizers sharing one learning rate.

    `loss_rate` is in fact a learning rate; the public name is kept for
    backward compatibility with existing callers.
    """
    with strategy.scope():
        monet_generator_optimizer = tf.keras.optimizers.Adam(loss_rate, beta_1=0.5)
        photo_generator_optimizer = tf.keras.optimizers.Adam(loss_rate, beta_1=0.5)
        monet_discriminator_optimizer = tf.keras.optimizers.Adam(loss_rate, beta_1=0.5)
        photo_discriminator_optimizer = tf.keras.optimizers.Adam(loss_rate, beta_1=0.5)

    cycle_gan_compile(monet_generator_optimizer,
                      photo_generator_optimizer,
                      monet_discriminator_optimizer,
                      photo_discriminator_optimizer)


if __name__ == '__main__':
    BATCH_SIZE = 32

    full_dataset = get_gan_dataset(augment=data_augment_flip,
                                   shuffle=False,
                                   batch_size=BATCH_SIZE)

    cycle_gan_compile_with_loss_rate(2e-4)  # Adam learning rate
    cycle_gan_model.fit(full_dataset,
                        epochs=1,
                        steps_per_epoch=1)

    # f-prefixes dropped: the original literals contained no placeholders.
    cycle_gan_model.monet_generator.save_weights('photo2monet.h5')
    cycle_gan_model.photo_generator.save_weights('monet2photo.h5')
import sys
from pathlib import Path
from numpy import uint8
from PIL import Image


if __name__ == '__main__':
    photos = photo_dataset().batch(32 * strategy.num_replicas_in_sync).prefetch(32)
    try:
        # f-prefixes dropped: the literals contain no placeholders.
        cycle_gan_model.monet_generator.load_weights('photo2monet.h5')
        cycle_gan_model.photo_generator.load_weights('monet2photo.h5')
    except Exception:
        # Weight files missing/unreadable.  The original bare `except:` also
        # swallowed KeyboardInterrupt/SystemExit.
        sys.exit('Model not trained yet.')

    out_dir = Path('../submission_images')
    out_dir.mkdir(parents=True, exist_ok=True)

    i = 1
    for batch in photos:
        # BUG FIX: the original kept only `...(img, training=False)[0]`,
        # silently discarding all but the first image of every batch (so only
        # ~1/32 of the photos were translated).  Save every prediction.
        predictions = cycle_gan_model.monet_generator(batch, training=False).numpy()
        for prediction in predictions:
            pixels = (prediction * 127.5 + 127.5).astype(uint8)  # [-1, 1] -> [0, 255]
            Image.fromarray(pixels).save(out_dir / f'{i}.jpg')
            i += 1


import shutil
# BUG FIX: trailing slash removed from base_name.  `shutil.make_archive`
# appends ".zip" to base_name, so '/kaggle/working/images/' would have
# produced a file named '.zip' inside an 'images' directory.
shutil.make_archive('/kaggle/working/images', 'zip', '/kaggle/submission_images/')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment