Created
February 19, 2023 18:35
-
-
Save airicbear/0a923f12ae4d1187095616f294550ad3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{"cells":[{"cell_type":"code","execution_count":43,"metadata":{"execution":{"iopub.execute_input":"2023-02-18T22:25:29.702897Z","iopub.status.busy":"2023-02-18T22:25:29.702552Z","iopub.status.idle":"2023-02-18T22:25:29.711940Z","shell.execute_reply":"2023-02-18T22:25:29.711234Z","shell.execute_reply.started":"2023-02-18T22:25:29.702844Z"},"trusted":true},"outputs":[],"source":["import tensorflow as tf\n","\n","USING_KAGGLE = False\n","AUTOTUNE = tf.data.experimental.AUTOTUNE\n","\n","try:\n"," tpu = tf.distribute.cluster_resolver.TPUClusterResolver()\n"," tf.config.experimental_connect_to_cluster(tpu)\n"," tf.tpu.experimental.initialize_tpu_system(tpu)\n"," strategy = tf.distribute.experimental.TPUStrategy(tpu)\n","except:\n"," strategy = tf.distribute.get_strategy()"]},{"cell_type":"code","execution_count":44,"metadata":{"execution":{"iopub.execute_input":"2023-02-18T22:25:29.715487Z","iopub.status.busy":"2023-02-18T22:25:29.714811Z","iopub.status.idle":"2023-02-18T22:25:29.724322Z","shell.execute_reply":"2023-02-18T22:25:29.723533Z","shell.execute_reply.started":"2023-02-18T22:25:29.715443Z"},"trusted":true},"outputs":[],"source":["import tensorflow as tf\n","\n","def decode_image(image: tf.io.FixedLenFeature, image_size: 'list[int]' = [256, 256]) -> tf.Tensor:\n"," image = tf.image.decode_jpeg(image, channels=3)\n"," image = (tf.cast(image, tf.float32) / 127.5) - 1\n"," image = tf.reshape(image, [*image_size, 3])\n"," return image\n","\n","def read_tfrecord(example: tf.Tensor) -> tf.Tensor:\n"," tfrecord_format = {\n"," 'image': tf.io.FixedLenFeature([], tf.string),\n"," }\n"," example = tf.io.parse_single_example(example, tfrecord_format)\n"," image = decode_image(example['image'])\n"," return image\n","\n","def read_tfrecords(filenames: 'list[str]') -> tf.data.TFRecordDataset:\n"," dataset = tf.data.TFRecordDataset(filenames)\n"," dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)\n"," return 
dataset\n"]},{"cell_type":"code","execution_count":45,"metadata":{"execution":{"iopub.execute_input":"2023-02-18T22:25:29.726118Z","iopub.status.busy":"2023-02-18T22:25:29.725782Z","iopub.status.idle":"2023-02-18T22:25:29.740727Z","shell.execute_reply":"2023-02-18T22:25:29.739957Z","shell.execute_reply.started":"2023-02-18T22:25:29.726086Z"},"trusted":true},"outputs":[],"source":["from typing import Callable\n","\n","import tensorflow as tf\n","import numpy as np\n","import re\n","\n","def dataset_path(path: str = '../Sample_Data/Raw') -> str:\n"," \"\"\"This function returns the dataset root path: the Kaggle GCS path when `USING_KAGGLE` is true, otherwise `path`.\"\"\"\n"," if USING_KAGGLE:\n"," from kaggle_datasets import KaggleDatasets\n"," return KaggleDatasets().get_gcs_path()\n"," else:\n"," return path\n","\n","\n","def monet_filenames():\n"," \"\"\"This sub-function returns the filenames for the Monet paintings\"\"\"\n"," return tf.io.gfile.glob(f'{dataset_path()}/monet_tfrec/*.tfrec')\n","\n","\n","def photo_filenames():\n"," \"\"\"This sub-function returns the filenames for the photos\"\"\"\n"," return tf.io.gfile.glob(f'{dataset_path()}/photo_tfrec/*.tfrec')\n","\n","\n","def monet_dataset():\n"," \"\"\"This sub-function loads the Monet paintings in `TFRecordDataset` format\"\"\"\n"," return read_tfrecords(monet_filenames())\n","\n","\n","def photo_dataset():\n"," \"\"\"This sub-function loads the photos in `TFRecordDataset` format\"\"\"\n"," return read_tfrecords(photo_filenames())\n","\n","\n","def count_data_items(filenames) -> np.array:\n"," n: list\n","\n"," n = [int(re.compile(r\"-([0-9]*)\\.\").search(filename).group(1)) for filename in filenames]\n","\n"," return np.sum(n)\n","\n","\n","def count_monet_samples():\n"," return count_data_items(monet_filenames())\n","\n","\n","def count_photo_samples():\n"," return count_data_items(photo_filenames())\n","\n","\n","def get_gan_dataset(augment: Callable[[tf.Tensor], tf.Tensor] = None,\n"," repeat: bool = True,\n"," shuffle: bool = 
True,\n"," batch_size: int = 1):\n"," monets = monet_dataset()\n"," photos = photo_dataset()\n","\n"," if repeat:\n"," monets = monets.repeat()\n"," photos = photos.repeat()\n","\n"," if shuffle:\n"," monets = monets.shuffle(2048)\n"," photos = photos.shuffle(2048)\n","\n"," monets = monets.batch(batch_size, drop_remainder=True)\n"," photos = photos.batch(batch_size, drop_remainder=True)\n","\n"," if augment:\n"," monets = monets.map(augment, num_parallel_calls=AUTOTUNE)\n"," photos = photos.map(augment, num_parallel_calls=AUTOTUNE)\n","\n"," monets = monets.prefetch(AUTOTUNE)\n"," photos = photos.prefetch(AUTOTUNE)\n","\n"," gan_dataset = tf.data.Dataset.zip((monets, photos))\n","\n"," return gan_dataset\n"]},{"cell_type":"code","execution_count":46,"metadata":{"execution":{"iopub.execute_input":"2023-02-18T22:25:29.742325Z","iopub.status.busy":"2023-02-18T22:25:29.742012Z","iopub.status.idle":"2023-02-18T22:25:29.754563Z","shell.execute_reply":"2023-02-18T22:25:29.753899Z","shell.execute_reply.started":"2023-02-18T22:25:29.742294Z"},"trusted":true},"outputs":[],"source":["import tensorflow as tf\n","from tensorflow_addons.layers import InstanceNormalization\n","\n","def downsample(filters: int, size: int, apply_instancenorm: bool = True) -> tf.keras.Sequential:\n"," initializer = tf.random_normal_initializer(0., 0.02)\n"," gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)\n"," \n"," result = tf.keras.Sequential()\n"," result.add(tf.keras.layers.Conv2D(filters, size, strides=2, padding='same', kernel_initializer=initializer, use_bias=False))\n"," if apply_instancenorm:\n"," result.add(InstanceNormalization(gamma_initializer=gamma_init))\n"," result.add(tf.keras.layers.ReLU())\n","\n"," return result\n","\n","def upsample(filters: int, size: int, apply_dropout: bool = False) -> tf.keras.Sequential:\n"," initializer = tf.random_normal_initializer(0., 0.02)\n"," gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)\n","\n"," 
result = tf.keras.Sequential()\n"," result.add(tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding='same', kernel_initializer=initializer, use_bias=False))\n"," result.add(InstanceNormalization(gamma_initializer=gamma_init))\n"," if apply_dropout:\n"," result.add(tf.keras.layers.Dropout(0.5))\n"," result.add(tf.keras.layers.ReLU())\n","\n"," return result"]},{"cell_type":"code","execution_count":47,"metadata":{"execution":{"iopub.execute_input":"2023-02-18T22:25:29.756689Z","iopub.status.busy":"2023-02-18T22:25:29.756423Z","iopub.status.idle":"2023-02-18T22:25:29.768980Z","shell.execute_reply":"2023-02-18T22:25:29.768248Z","shell.execute_reply.started":"2023-02-18T22:25:29.756658Z"},"trusted":true},"outputs":[],"source":["import tensorflow as tf\n","from typing import List\n","\n","def Generator(output_channels: int = 3) -> tf.keras.Model:\n"," \"\"\"\n"," The Generator applies an encoder-decoder architecture that downsamples the image and then decodes the image\n"," through various transpose convolutions. It is inspired by a \"U-Net\" architecture for image generation, which downsamples\n"," an image and then upsamples it, concatenating each upsampled output with the matching downsampled output (skip connection).\n"," \"\"\"\n"," down_stack = [\n"," downsample(64, 4, apply_instancenorm=False),\n"," downsample(128, 4),\n"," downsample(256, 4),\n"," downsample(512, 4),\n"," downsample(512, 4),\n"," downsample(512, 4),\n"," downsample(512, 4),\n"," downsample(512, 4),\n"," ]\n","\n"," up_stack = [\n"," upsample(512, 4, apply_dropout=True),\n"," upsample(512, 4, apply_dropout=True),\n"," upsample(512, 4, apply_dropout=True),\n"," upsample(512, 4),\n"," upsample(256, 4),\n"," upsample(128, 4),\n"," upsample(64, 4),\n"," ]\n","\n"," initializer = tf.random_normal_initializer(0., 0.02)\n","\n"," inputs = tf.keras.layers.Input(shape=[256,256,3])\n"," x = inputs\n","\n"," skips: List[tf.keras.Sequential] = []\n"," for down in down_stack:\n"," x = down(x)\n"," skips.append(x)\n","\n"," skips = list(reversed(skips[:-1]))\n","\n"," for up, skip in 
zip(up_stack, skips):\n"," x = up(x)\n"," x = tf.keras.layers.Concatenate()([x, skip])\n","\n"," x = tf.keras.layers.Conv2DTranspose(output_channels,\n"," 4,\n"," strides=2,\n"," padding='same',\n"," kernel_initializer=initializer,\n"," activation='tanh')(x)\n","\n"," return tf.keras.Model(inputs=inputs, outputs=x)\n"]},{"cell_type":"code","execution_count":48,"metadata":{"execution":{"iopub.execute_input":"2023-02-18T22:25:29.770875Z","iopub.status.busy":"2023-02-18T22:25:29.770111Z","iopub.status.idle":"2023-02-18T22:25:29.782620Z","shell.execute_reply":"2023-02-18T22:25:29.781880Z","shell.execute_reply.started":"2023-02-18T22:25:29.770841Z"},"trusted":true},"outputs":[],"source":["import tensorflow as tf\n","from tensorflow_addons.layers import InstanceNormalization\n","\n","def Discriminator() -> tf.keras.Model:\n"," initializer = tf.random_normal_initializer(0., 0.02)\n"," gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)\n","\n"," inputs = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')\n","\n"," x = inputs\n"," down1 = downsample(64, 4, False)(x)\n"," down2 = downsample(128, 4)(down1)\n"," down3 = downsample(256, 4)(down2)\n"," zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)\n"," conv = tf.keras.layers.Conv2D(512, 4, strides=1, kernel_initializer=initializer, use_bias=False)(zero_pad1)\n"," norm1 = InstanceNormalization(gamma_initializer=gamma_init)(conv)\n"," leaky_relu = tf.keras.layers.LeakyReLU()(norm1)\n"," zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)\n"," outputs = tf.keras.layers.Conv2D(1, 4, strides=1, kernel_initializer=initializer)(zero_pad2)\n","\n"," return tf.keras.Model(inputs=inputs, 
outputs=outputs)\n"]},{"cell_type":"code","execution_count":49,"metadata":{"execution":{"iopub.execute_input":"2023-02-18T22:25:29.786164Z","iopub.status.busy":"2023-02-18T22:25:29.785957Z","iopub.status.idle":"2023-02-18T22:25:29.817426Z","shell.execute_reply":"2023-02-18T22:25:29.816647Z","shell.execute_reply.started":"2023-02-18T22:25:29.786140Z"},"trusted":true},"outputs":[],"source":["from typing import Callable, Dict, List\n","\n","import tensorflow as tf\n","\n","with strategy.scope():\n"," def diff_augment(x: tf.Tensor,\n"," policy: str = '',\n"," channels_first: bool = False) -> tf.Tensor:\n"," if policy:\n"," if channels_first:\n"," x = tf.transpose(x, [0, 2, 3, 1])\n"," for p in policy.split(','):\n"," for fn in AUGMENT_FNS[p]:\n"," x = fn(x)\n"," if channels_first:\n"," x = tf.transpose(x, [0, 3, 1, 2])\n"," return x\n","\n","\n"," def rand_brightness(x: tf.Tensor) -> tf.Tensor:\n"," magnitude: tf.Tensor\n","\n"," magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) - 0.5\n"," x = x + magnitude\n","\n"," return x\n","\n","\n"," def rand_saturation(x: tf.Tensor) -> tf.Tensor:\n"," magnitude: tf.Tensor\n"," x_mean: tf.Tensor\n","\n"," magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) * 2\n"," x_mean = tf.reduce_mean(x, axis=3, keepdims=True) * 0.3333333333333333333\n"," x = (x - x_mean) * magnitude + x_mean\n","\n"," return x\n","\n","\n"," def rand_contrast(x: tf.Tensor) -> tf.Tensor:\n"," magnitude: tf.Tensor\n"," x_mean: tf.Tensor\n","\n"," magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) + 0.5\n"," x_mean = tf.reduce_mean(x, axis=[1, 2, 3], keepdims=True) * 5.086e-6\n"," x = (x - x_mean) * magnitude + x_mean\n","\n"," return x\n","\n","\n"," def rand_translation(x: tf.Tensor,\n"," ratio: float = 0.125) -> tf.Tensor:\n"," batch_size: tf.Tensor\n"," image_size: tf.Tensor\n"," shift: tf.Tensor\n"," translation_x: tf.Tensor\n"," translation_y: tf.Tensor\n"," grid_x: tf.Tensor\n"," grid_y: tf.Tensor\n","\n"," batch_size = 
tf.shape(x)[0]\n"," image_size = tf.shape(x)[1:3]\n"," shift = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)\n"," translation_x = tf.random.uniform([batch_size, 1], -shift[0], shift[0] + 1, dtype=tf.int32)\n"," translation_y = tf.random.uniform([batch_size, 1], -shift[1], shift[1] + 1, dtype=tf.int32)\n"," grid_x = tf.clip_by_value(tf.expand_dims(tf.range(image_size[0], dtype=tf.int32), 0) + translation_x + 1,\n"," 0,\n"," image_size[0] + 1)\n"," grid_y = tf.clip_by_value(tf.expand_dims(tf.range(image_size[1], dtype=tf.int32), 0) + translation_y + 1,\n"," 0,\n"," image_size[1] + 1)\n"," x = tf.gather_nd(tf.pad(x, [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_x, -1), batch_dims=1)\n"," x = tf.transpose(tf.gather_nd(tf.pad(tf.transpose(x, [0, 2, 1, 3]), [[0, 0], [1, 1], [0, 0], [0, 0]]),\n"," tf.expand_dims(grid_y, -1),\n"," batch_dims=1),\n"," [0, 2, 1, 3])\n","\n"," return x\n","\n","\n"," def rand_cutout(x: tf.Tensor,\n"," ratio: float = 0.5) -> tf.Tensor:\n"," batch_size: tf.Tensor\n"," image_size: tf.Tensor\n"," cutout_size: tf.Tensor\n"," offset_x: tf.Tensor\n"," offset_y: tf.Tensor\n"," grid_batch: tf.Tensor\n"," grid_x: tf.Tensor\n"," grid_y: tf.Tensor\n"," cutout_grid: tf.Tensor\n"," mask: tf.Tensor\n","\n"," batch_size = tf.shape(x)[0]\n"," image_size = tf.shape(x)[1:3]\n"," cutout_size = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)\n"," offset_x = tf.random.uniform([tf.shape(x)[0], 1, 1],\n"," maxval=image_size[0] + (1 - cutout_size[0] % 2),\n"," dtype=tf.int32)\n"," offset_y = tf.random.uniform([tf.shape(x)[0], 1, 1],\n"," maxval=image_size[1] + (1 - cutout_size[1] % 2),\n"," dtype=tf.int32)\n"," grid_batch, grid_x, grid_y = tf.meshgrid(tf.range(batch_size, dtype=tf.int32),\n"," tf.range(cutout_size[0], dtype=tf.int32),\n"," tf.range(cutout_size[1], dtype=tf.int32),\n"," indexing='ij')\n"," cutout_grid = tf.stack([grid_batch,\n"," grid_x + offset_x - cutout_size[0] // 2,\n"," grid_y + offset_y - 
cutout_size[1] // 2],\n"," axis=-1)\n"," mask_shape = tf.stack([batch_size, image_size[0], image_size[1]])\n"," cutout_grid = tf.maximum(cutout_grid, 0)\n"," cutout_grid = tf.minimum(cutout_grid, tf.reshape(mask_shape - 1, [1, 1, 1, 3]))\n"," mask = tf.maximum(1 - tf.scatter_nd(cutout_grid,\n"," tf.ones([batch_size, cutout_size[0], cutout_size[1]], dtype=tf.float32),\n"," mask_shape),\n"," 0)\n"," x = x * tf.expand_dims(mask, axis=3)\n","\n"," return x\n","\n","\n"," AUGMENT_FNS: Dict[str, List[Callable[[tf.Tensor], tf.Tensor]]] = {\n"," 'color': [rand_brightness, rand_saturation, rand_contrast],\n"," 'translation': [rand_translation],\n"," 'cutout': [rand_cutout],\n"," }\n","\n","\n"," def aug_fn(image: tf.Tensor) -> tf.Tensor:\n"," return diff_augment(image,\n"," \"color,translation,cutout\")\n","\n","\n","def data_augment_flip(image: tf.Tensor) -> tf.Tensor:\n"," image = tf.image.random_flip_left_right(image)\n"," return image\n"]},{"cell_type":"code","execution_count":50,"metadata":{"execution":{"iopub.execute_input":"2023-02-18T22:25:29.819262Z","iopub.status.busy":"2023-02-18T22:25:29.818731Z","iopub.status.idle":"2023-02-18T22:25:29.839088Z","shell.execute_reply":"2023-02-18T22:25:29.838394Z","shell.execute_reply.started":"2023-02-18T22:25:29.819227Z"},"trusted":true},"outputs":[],"source":["import tensorflow as tf\n","from typing import Callable\n","\n","\n","class CycleGan(tf.keras.Model):\n"," \"\"\"\n"," This class is the CycleGAN model, which stores the generators and discriminators and compiles their optimizers and loss functions\n"," \"\"\"\n"," \n"," def __init__(self,\n"," monet_generator: tf.keras.Model,\n"," photo_generator: tf.keras.Model,\n"," monet_discriminator: tf.keras.Model,\n"," photo_discriminator: tf.keras.Model,\n"," lambda_cycle: int = 3,\n"," lambda_id: int = 3,\n"," ):\n"," \"\"\"\n"," This is the initialization function for the generators, discriminators, and the loss weights `lambda_cycle` and `lambda_id`\n"," \"\"\"\n"," super(CycleGan, self).__init__()\n"," self.monet_generator = 
monet_generator\n"," self.photo_generator = photo_generator\n"," self.monet_discriminator = monet_discriminator\n"," self.photo_discriminator = photo_discriminator\n"," self.lambda_cycle = lambda_cycle\n"," self.lambda_id = lambda_id\n"," \n"," def compile(self,\n"," monet_generator_optimizer: tf.keras.optimizers.Optimizer,\n"," photo_generator_optimizer: tf.keras.optimizers.Optimizer,\n"," monet_discriminator_optimizer: tf.keras.optimizers.Optimizer,\n"," photo_discriminator_optimizer: tf.keras.optimizers.Optimizer,\n"," generator_loss_fn: Callable[[tf.keras.Model], tf.Tensor],\n"," discriminator_loss_fn: Callable[[tf.keras.Model, tf.keras.Model], tf.Tensor],\n"," cycle_loss_fn: Callable[[tf.Tensor, tf.Tensor, float], tf.Tensor],\n"," identity_loss_fn: Callable[[tf.Tensor, tf.Tensor, float], tf.Tensor],\n"," aug_fn):\n"," \"\"\"\n"," This function sets the optimizers and the loss functions\n"," \"\"\"\n"," super(CycleGan, self).compile()\n"," self.monet_generator_optimizer = monet_generator_optimizer\n"," self.photo_generator_optimizer = photo_generator_optimizer\n"," self.monet_discriminator_optimizer = monet_discriminator_optimizer\n"," self.photo_discriminator_optimizer = photo_discriminator_optimizer\n"," self.generator_loss_fn = generator_loss_fn\n"," self.discriminator_loss_fn = discriminator_loss_fn\n"," self.cycle_loss_fn = cycle_loss_fn\n"," self.identity_loss_fn = identity_loss_fn\n"," self.aug_fn = aug_fn\n"," \n","\n","\n"," @tf.function\n"," def train_step(self, batch_data: 'tuple[tf.Tensor, tf.Tensor]'):\n"," real_monet, real_photo = batch_data\n"," batch_size = tf.shape(real_monet)[0]\n","\n"," with tf.GradientTape(persistent=True) as tape:\n"," fake_monet = self.monet_generator(real_photo, training=True)\n"," cycled_photo = self.photo_generator(fake_monet, training=True)\n","\n"," fake_photo = self.photo_generator(real_monet, training=True)\n"," cycled_monet = self.monet_generator(fake_photo, training=True)\n","\n"," same_monet = 
self.monet_generator(real_monet, training=True)\n"," same_photo = self.photo_generator(real_photo, training=True)\n","\n"," both_monet = tf.concat([real_monet, fake_monet], axis=0)\n","\n"," aug_monet = self.aug_fn(both_monet)\n","\n"," aug_real_monet = aug_monet[:batch_size]\n"," aug_fake_monet = aug_monet[batch_size:]\n","\n"," discriminator_real_monet = self.monet_discriminator(aug_real_monet, training=True)\n"," discriminator_real_photo = self.photo_discriminator(real_photo, training=True)\n","\n"," discriminator_fake_monet = self.monet_discriminator(aug_fake_monet, training=True)\n"," discriminator_fake_photo = self.photo_discriminator(fake_photo, training=True)\n","\n"," monet_generator_loss = self.generator_loss_fn(discriminator_fake_monet)\n"," photo_generator_loss = self.generator_loss_fn(discriminator_fake_photo)\n","\n"," total_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle / tf.cast(batch_size, tf.float32)) \\\n"," + self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle / tf.cast(batch_size, tf.float32))\n","\n"," total_monet_generator_loss = monet_generator_loss + total_cycle_loss + self.identity_loss_fn(real_monet, same_monet, self.lambda_id / tf.cast(batch_size, tf.float32))\n"," total_photo_generator_loss = photo_generator_loss + total_cycle_loss + self.identity_loss_fn(real_photo, same_photo, self.lambda_id / tf.cast(batch_size, tf.float32))\n","\n"," monet_discriminator_loss = self.discriminator_loss_fn(discriminator_real_monet, discriminator_fake_monet)\n"," photo_discriminator_loss = self.discriminator_loss_fn(discriminator_real_photo, discriminator_fake_photo)\n","\n"," monet_generator_gradients = tape.gradient(total_monet_generator_loss, self.monet_generator.trainable_variables)\n"," photo_generator_gradients = tape.gradient(total_photo_generator_loss, self.photo_generator.trainable_variables)\n"," \n"," monet_discriminator_gradients = tape.gradient(monet_discriminator_loss, 
self.monet_discriminator.trainable_variables)\n"," photo_discriminator_gradients = tape.gradient(photo_discriminator_loss, self.photo_discriminator.trainable_variables)\n","\n"," self.monet_generator_optimizer.apply_gradients(zip(monet_generator_gradients, self.monet_generator.trainable_variables))\n"," self.photo_generator_optimizer.apply_gradients(zip(photo_generator_gradients, self.photo_generator.trainable_variables))\n","\n"," self.monet_discriminator_optimizer.apply_gradients(zip(monet_discriminator_gradients, self.monet_discriminator.trainable_variables))\n"," self.photo_discriminator_optimizer.apply_gradients(zip(photo_discriminator_gradients, self.photo_discriminator.trainable_variables))\n","\n"," return {\n"," 'monet_generator_loss': total_monet_generator_loss,\n"," 'photo_generator_loss': total_photo_generator_loss,\n"," 'monet_discriminator_loss': monet_discriminator_loss,\n"," 'photo_discriminator_loss': photo_discriminator_loss\n"," }"]},{"cell_type":"code","execution_count":51,"metadata":{"execution":{"iopub.execute_input":"2023-02-18T22:25:29.841110Z","iopub.status.busy":"2023-02-18T22:25:29.840495Z","iopub.status.idle":"2023-02-18T22:25:29.852611Z","shell.execute_reply":"2023-02-18T22:25:29.851937Z","shell.execute_reply.started":"2023-02-18T22:25:29.841070Z"},"trusted":true},"outputs":[],"source":["import tensorflow as tf\n","\n","with strategy.scope():\n"," def discriminator_loss(real: tf.keras.Model, generated: tf.keras.Model) -> tf.Tensor:\n"," real_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True,\n"," reduction=tf.keras.losses.Reduction.NONE)(tf.ones_like(real),\n"," real)\n"," generated_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True,\n"," reduction=tf.keras.losses.Reduction.NONE)(\n"," tf.zeros_like(generated), generated)\n"," total_discriminator_loss = real_loss + generated_loss\n"," return total_discriminator_loss * 0.5\n","\n","\n"," def generator_loss(generated: tf.keras.Model) -> tf.Tensor:\n"," return 
tf.keras.losses.BinaryCrossentropy(from_logits=True,\n"," reduction=tf.keras.losses.Reduction.NONE)(tf.ones_like(generated),\n"," generated)\n","\n","\n"," def calc_cycle_loss(real_image: tf.Tensor, cycled_image: tf.Tensor, alpha: float) -> tf.Tensor:\n"," loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))\n"," return alpha * loss1\n","\n","\n"," def identity_loss(real_image: tf.Tensor, same_image: tf.Tensor, alpha: float) -> tf.Tensor:\n"," loss = tf.reduce_mean(tf.abs(real_image - same_image))\n"," return alpha * 0.5 * loss\n"]},{"cell_type":"code","execution_count":52,"metadata":{"execution":{"iopub.execute_input":"2023-02-18T22:25:29.854252Z","iopub.status.busy":"2023-02-18T22:25:29.853971Z","iopub.status.idle":"2023-02-18T22:25:31.914875Z","shell.execute_reply":"2023-02-18T22:25:31.914116Z","shell.execute_reply.started":"2023-02-18T22:25:29.854207Z"},"trusted":true},"outputs":[],"source":["import tensorflow as tf\n","\n","monet_generator = Generator()\n","photo_generator = Generator()\n","monet_discriminator = Discriminator()\n","photo_discriminator = Discriminator()\n","\n","cycle_gan_model = CycleGan(monet_generator, photo_generator, monet_discriminator, photo_discriminator)\n","\n","\n","def cycle_gan_compile(monet_generator_optimizer: tf.keras.optimizers.Optimizer,\n"," photo_generator_optimizer: tf.keras.optimizers.Optimizer,\n"," monet_discriminator_optimizer: tf.keras.optimizers.Optimizer,\n"," photo_discriminator_optimizer: tf.keras.optimizers.Optimizer) -> None:\n"," cycle_gan_model.compile(\n"," monet_generator_optimizer=monet_generator_optimizer,\n"," photo_generator_optimizer=photo_generator_optimizer,\n"," monet_discriminator_optimizer=monet_discriminator_optimizer,\n"," photo_discriminator_optimizer=photo_discriminator_optimizer,\n"," generator_loss_fn=generator_loss,\n"," discriminator_loss_fn=discriminator_loss,\n"," cycle_loss_fn=calc_cycle_loss,\n"," identity_loss_fn=identity_loss,\n"," aug_fn=aug_fn,\n"," )\n","\n","\n","def 
cycle_gan_compile_with_loss_rate(loss_rate: float) -> None:\n"," with strategy.scope():\n"," monet_generator_optimizer = tf.keras.optimizers.Adam(loss_rate, beta_1=0.5)\n"," photo_generator_optimizer = tf.keras.optimizers.Adam(loss_rate, beta_1=0.5)\n"," monet_discriminator_optimizer = tf.keras.optimizers.Adam(loss_rate, beta_1=0.5)\n"," photo_discriminator_optimizer = tf.keras.optimizers.Adam(loss_rate, beta_1=0.5)\n","\n"," cycle_gan_compile(monet_generator_optimizer,\n"," photo_generator_optimizer,\n"," monet_discriminator_optimizer,\n"," photo_discriminator_optimizer)\n"]},{"cell_type":"code","execution_count":53,"metadata":{"execution":{"iopub.execute_input":"2023-02-18T22:25:31.916491Z","iopub.status.busy":"2023-02-18T22:25:31.916224Z","iopub.status.idle":"2023-02-18T22:26:10.585575Z","shell.execute_reply":"2023-02-18T22:26:10.584005Z","shell.execute_reply.started":"2023-02-18T22:25:31.916457Z"},"trusted":true},"outputs":[],"source":["if __name__ == '__main__':\n"," BATCH_SIZE = 32\n","\n"," full_dataset = get_gan_dataset(augment=data_augment_flip,\n"," shuffle=False,\n"," batch_size=BATCH_SIZE)\n","\n"," cycle_gan_compile_with_loss_rate(2e-4)\n"," cycle_gan_model.fit(full_dataset,\n"," epochs=1,\n"," steps_per_epoch=1)\n","\n"," cycle_gan_model.monet_generator.save_weights(f'photo2monet.h5')\n"," cycle_gan_model.photo_generator.save_weights(f'monet2photo.h5')\n"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.status.busy":"2023-02-18T22:26:10.586849Z","iopub.status.idle":"2023-02-18T22:26:10.587271Z","shell.execute_reply":"2023-02-18T22:26:10.587067Z","shell.execute_reply.started":"2023-02-18T22:26:10.587046Z"},"trusted":true},"outputs":[],"source":["import sys\n","from pathlib import Path\n","from numpy import uint8\n","from PIL import Image\n","\n","\n","if __name__ == '__main__':\n"," photos = photo_dataset().batch(32 * strategy.num_replicas_in_sync).prefetch(32)\n"," try:\n"," 
cycle_gan_model.monet_generator.load_weights(f'photo2monet.h5')\n"," cycle_gan_model.photo_generator.load_weights(f'monet2photo.h5')\n"," except:\n"," sys.exit('Model not trained yet.')\n","\n"," Path('../submission_images').mkdir(parents=True, exist_ok=True)\n","\n"," i = 1\n"," for img in photos:\n"," prediction = cycle_gan_model.monet_generator(img, training=False)[0].numpy()\n"," prediction = (prediction * 127.5 + 127.5).astype(uint8)\n"," im = Image.fromarray(prediction)\n"," im.save(f'../submission_images/{i}.jpg')\n"," i += 1\n"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.status.busy":"2023-02-18T22:26:10.588452Z","iopub.status.idle":"2023-02-18T22:26:10.588941Z","shell.execute_reply":"2023-02-18T22:26:10.588733Z","shell.execute_reply.started":"2023-02-18T22:26:10.588711Z"},"trusted":true},"outputs":[],"source":["import shutil\n","shutil.make_archive('/kaggle/working/images/', 'zip', '/kaggle/submission_images/')"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":[]}],"metadata":{"kernelspec":{"display_name":"monet-cyclegan-windows","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.9.16"},"vscode":{"interpreter":{"hash":"7586613da85f4e39e0fc23a5ffc537163f924aa60a8fb26eca5b6a25f5299bc3"}}},"nbformat":4,"nbformat_minor":4} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment