ruDALLE-Outpainting.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/eyaler/0cee9a71f5dd3fdfa9c0c03656ebdd4c/rudalle-outpainting.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Py2MOWx5kNxH"
},
"source": [
"# ruDALL-E Outpainting\n", | |
"\n", | |
"ruDALL-E article: https://habr.com/ru/company/sberbank/blog/589673\n", | |
"\n", | |
"Original notebook: https://colab.research.google.com/github/sberbank-ai/ru-dalle/blob/master/jupyters/ruDALLE-image-prompts-A100.ipynb\n", | |
"\n", | |
"Arbitraty resolution notebook (not implemented here yet): https://colab.research.google.com/drive/1DbqOIUIVBPOrJ4MeaV4YkAlb7ilWQjKZ\n", | |
"\n", | |
"Inspired by: https://twitter.com/MichaelFriese10/status/1456023409213726725\n", | |
"\n", | |
"Experiments twitter thread I: https://twitter.com/eyaler/status/1468682110860992521\n", | |
"\n", | |
"Experiments twitter thread II: https://twitter.com/eyaler/status/1470150704488660993\n", | |
"\n", | |
"More results image gallery: https://imgur.com/gallery/tcwYSzM\n", | |
"\n", | |
"Shortcut to this notebook: [j.mp/outpaint](https://j.mp/outpaint)\n", | |
"\n", | |
"Notebook by: [Eyal Gruss](https://eyalgruss.com) \\([@eyaler](https://twitter.com/eyaler)\\)\n", | |
"\n", | |
"A curated list of online generative tools: [j.mp/generativetools](https://j.mp/generativetools)" | |
],
"id": "Py2MOWx5kNxH"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "118b2319",
"cellView": "form"
},
"outputs": [],
"source": [
"#@title Setup {run: 'auto'}\n",
"\n",
"!pip install rudalle==1.0.0 > /dev/null 2>&1\n", | |
"!pip install ruclip==0.0.1 > /dev/null 2>&1\n", | |
"!pip install translators==4.11.0 > /dev/null 2>&1\n", | |
"\n", | |
"\n", | |
"from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_ruclip\n", | |
"from rudalle.image_prompts import ImagePrompts\n", | |
"from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan\n", | |
"from rudalle.utils import seed_everything\n", | |
"import ruclip\n", | |
"from PIL import Image, ImageOps\n", | |
"import torch\n", | |
"from google.colab import files, _message\n", | |
"import numpy as np\n", | |
"import translators\n", | |
"import requests\n", | |
"import os\n", | |
"import json\n", | |
"\n", | |
"\n", | |
"model = 'Malevich_v3' #@param ['Malevich_v3', 'Malevich_v2', 'Emojich']\n", | |
"if model == 'Malevich_v3':\n", | |
" model = 'Malevich'\n", | |
"dalle = get_rudalle_model(model, pretrained=True, fp16=True, device='cuda')\n", | |
"realesrgan = {x: get_realesrgan('x%d'%x, device='cuda') for x in [2,4,8]} \n", | |
"tokenizer = get_tokenizer()\n", | |
"vae = get_vae().to('cuda')\n", | |
"dwt_vae = get_vae(dwt=True).to('cuda')\n", | |
"clip, processor = ruclip.load('ruclip-vit-base-patch32-384', device='cuda')\n", | |
"clip_predictor = ruclip.Predictor(clip, processor, device='cuda')\n" | |
], | |
"id": "118b2319" | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"cellView": "form", | |
"id": "81b664b6" | |
}, | |
"outputs": [], | |
"source": [ | |
"#@title Upload images\n", | |
"#@markdown Click button below to select files or cancel to use default image\n", | |
"\n", | |
"#@markdown Note: Works best for square images. Consider padding/cropping your images beforehand\n", | |
"fallback_image_url = 'https://web.archive.org/web/20210213021410if_/https://www.gallerypop.co.uk/wp-content/uploads/2017/11/Soup.jpg'\n", | |
"try:\n", | |
" filenames\n", | |
"except Exception:\n", | |
" filenames = []\n", | |
"streams = []\n", | |
"try:\n", | |
" new_filenames = files.upload()\n", | |
"except Exception:\n", | |
" pass\n", | |
"else:\n", | |
" if new_filenames:\n", | |
" filenames = new_filenames\n", | |
"if not filenames:\n", | |
" streams = [requests.get(fallback_image_url, stream=True).raw]\n", | |
"orig_images = [ImageOps.exif_transpose(Image.open(f)) for f in filenames or streams]\n", | |
"if not filenames:\n", | |
" filenames = [fallback_image_url.rsplit('/',1)[-1]]" | |
],
"id": "81b664b6"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cebf4449",
"cellView": "form"
},
"outputs": [],
"source": [
"#@title Set options and run\n",
"#@markdown Take top part and complete bottom part (downpainting):\n", | |
"take_top = True #@param {type:'boolean'}\n", | |
"top_frac = 0.5 #@param {type:'slider', max:1, step:0.03125}\n", | |
"take_bottom = False\n", | |
"bottom_frac = 0.5\n", | |
"#take_bottom = False #@param {type:'boolean'}\n", | |
"#bootom_frac = 0.5 #@param {type:'slider', max:1, step:0.03125}\n", | |
"#@markdown Add additional runs taking different image parts (will probably not work so good):\n", | |
"flip_and_take_top = False #@param {type:'boolean'}\n", | |
"flipped_top_frac = 0.5 #@param {type:'slider', max:1, step:0.03125}\n", | |
"flip_result_back = 'no' #@param ['no', 'before_clip', 'finally']\n", | |
"take_left = False #@param {type:'boolean'}\n", | |
"left_frac = 0.5 #@param {type:'slider', max:1, step:0.03125}\n", | |
"take_right = False #@param {type:'boolean'}\n", | |
"right_frac = 0.5 #@param {type:'slider', max:1, step:0.03125}\n", | |
"#@markdown before_encoding will give less constrained results adequate for natual images but will give noncontextual completions otherwise:\n", | |
"crop_order = 'after_encoding' #@param ['after_encoding', 'before_encoding', 'both_after_and_before_runs']\n", | |
"#@markdown Increase for more diverse results, decrease for less:\n", | |
"temperature = 1#@param {type:'number'}\n", | |
"#@markdown Increase these for more outputs and better results, but will take longer. decrease for quick and dirty results:\n", | |
"num_levels = 1#@param {type:'slider', min:1, max:9}\n", | |
"retries_per_level = 4#@param {type:'integer'}\n", | |
"order_of_levels = 'high_to_low' #@param ['high_to_low', 'low_to_high', 'interleaved_extreme_to_middle', 'interleaved_middle_to_extreme']\n", | |
"#@markdown Optional text prompt (can leave empty; will auto-translate from any language to Russian - check if the back translation is OK):\n", | |
"text = '' #@param {type:'string'}\n", | |
"additional_run_without_text = False #@param {type:'boolean'}\n", | |
"#@markdown Output options:\n", | |
"fix_aspect_ratio = 'before_clip' #@param ['no', 'before_clip', 'after_clip']\n", | |
"paste_original_part_over_generated = 'no' #@param ['no', 'before_clip', 'between_clip_and_super_resolution', 'after_super_resolution','before_clip_and_again_after_super_resolution']\n", | |
"blend_frac = 0.1 #@param {type:'slider', max:1, step:0.005}\n", | |
"dwt_decoder_upscale = False #@param {type:'boolean'}\n", | |
"super_resolution_factor = 4 #@param [1, 2, 4, 8] {type:'raw'}\n", | |
"limit_display_results = 0#@param {type:'integer'}\n", | |
"display_width = 30 #@param {type:'number'}\n", | |
"#@markdown Set to a positive number for reproducing results (different results for different numbers):\n", | |
"random_seed = 42 #@param {type:'integer'}\n", | |
"if random_seed < 1:\n", | |
" random_seed = None\n", | |
"if random_seed:\n", | |
" seed_everything(random_seed)\n", | |
"if limit_display_results < 1 or limit_display_results > num_levels * retries_per_level:\n", | |
" limit_display_results = num_levels * retries_per_level\n", | |
"nrow = int(np.ceil((limit_display_results+2)/np.ceil((limit_display_results+2)/6)))\n", | |
"display_width = 30 \n", | |
"crop_orders = [0,1]\n", | |
"if crop_order == 'after_encoding':\n", | |
" crop_orders = [0]\n", | |
"elif crop_order == 'before_encoding':\n", | |
" crop_orders = [1]\n", | |
"assert take_top or flip_and_take_top or take_bottom or take_left or take_right, 'Select at least one of the take options'\n", | |
"\n", | |
"enc_size = 32\n", | |
"\n", | |
"def crop(im, borders, mask=None):\n", | |
" im = np.array(im)\n", | |
" if mask is not None:\n", | |
" mask = np.array(mask)\n", | |
" assert im.shape == mask.shape, (im.shape, mask.shape) \n", | |
" if borders['up'] is not None:\n", | |
" i = int(round(im.shape[0]*borders['up']/enc_size))\n", | |
" z = max(i-int(round(blend_frac*im.shape[0])), 0)\n", | |
" if mask is None:\n", | |
" im[i:] = 255\n", | |
" else:\n", | |
" im[:z] = mask[:z]\n", | |
" for j in range(z, i):\n", | |
" alpha = (j-z+1)/(i-z+1)\n", | |
" im[j] = im[j]*alpha+mask[j]*(1-alpha)\n", | |
" elif borders['down'] is not None:\n", | |
" i = int(round(im.shape[0]*borders['down']/enc_size))\n", | |
" z = min(i+int(round(blend_frac*im.shape[0])), im.shape[0])\n", | |
" if mask is None:\n", | |
" im[:i] = 255\n", | |
" else:\n", | |
" im[z:] = mask[z:]\n", | |
" for j in range(i, z):\n", | |
" alpha = (j-i+1)/(z-i+1)\n", | |
" im[j] = im[j]*(1-alpha)+mask[i:z]*alpha\n", | |
" elif borders['left'] is not None:\n", | |
" i = int(round(im.shape[1]*borders['left']/enc_size))\n", | |
" z = max(i-int(round(blend_frac*im.shape[1])), 0)\n", | |
" if mask is None:\n", | |
" im[:,i:] = 255\n", | |
" else:\n", | |
" im[:,:z] = mask[:,:z]\n", | |
" for j in range(z, i):\n", | |
" alpha = (j-z+1)/(i-z+1)\n", | |
" im[:,j] = im[:,j]*alpha+mask[:,j]*(1-alpha)\n", | |
" elif borders['right'] is not None:\n", | |
" i = int(round(im.shape[1]*borders['right']/enc_size))\n", | |
" z = min(i+int(round(blend_frac*im.shape[1])), im.shape[1])\n", | |
" if mask is None:\n", | |
" im[:,:i] = 255\n", | |
" else:\n", | |
" im[:,z:] = mask[:,z:]\n", | |
" for j in range(i, z):\n", | |
" alpha = (j-i+1)/(z-i+1)\n", | |
" im[:,j] = im[:,j]*(1-alpha)+mask[:,j]*alpha\n", | |
" im = Image.fromarray(im)\n", | |
" return im\n", | |
"\n", | |
"images = [im.resize((256,256)) for im in orig_images]\n", | |
"borders_flips = []\n", | |
"if take_top:\n", | |
" borders_flips.append(({'up': int(round(enc_size*top_frac)), 'left': None, 'right': None, 'down': None}, False))\n", | |
"if flip_and_take_top:\n", | |
" borders_flips.append(({'up': int(round(enc_size*flipped_top_frac)), 'left': None, 'right': None, 'down': None}, True))\n", | |
"if take_bottom:\n", | |
" borders_flips.append(({'up': None, 'left': None, 'right': None, 'down': int(round(enc_size*bottom_frac))}, False))\n", | |
"if take_left:\n", | |
" borders_flips.append(({'up': None, 'left': int(round(enc_size*left_frac)), 'right': None, 'down': None}, False))\n", | |
"if take_right:\n", | |
" borders_flips.append(({'up': None, 'left': None, 'right': int(round(enc_size*right_frac)), 'down': None}, False))\n", | |
"\n", | |
"def simple_detect_lang(text):\n", | |
" if len(set('абвгдежзийклмнопрстуфхцчшщъыьэюяё').intersection(text.lower())):\n", | |
" return 'ru'\n", | |
" if len(set('אבגדהוזחטיכךלמםנןסעפצץקרשת').intersection(text)):\n", | |
" return 'iw'\n", | |
" if len(set('abcdefghijklmnopqrstuvwxyz').intersection(text.lower())):\n", | |
" return 'en'\n", | |
" return 'auto'\n", | |
"\n", | |
"if text:\n", | |
" orig_text = text\n", | |
" lang = simple_detect_lang(text)\n", | |
" if lang != 'ru':\n", | |
" text = translators.google(text, from_language=lang, to_language='ru')\n", | |
" back_text = translators.google(text, from_language='ru', to_language=lang if lang not in ['auto', 'ru'] else 'en')\n", | |
" print('original text:', orig_text)\n", | |
" print('language detected:', lang)\n", | |
" print('prompt in russian:', text)\n", | |
" print('back translation:' if lang not in ['auto','ru'] else 'english translation', back_text)\n", | |
"texts = [text]\n", | |
"if text and additional_run_without_text:\n", | |
" texts.append('')\n", | |
"\n", | |
"levels = [\n", | |
" (2048, 0.995),\n", | |
" (1024, 0.98),\n", | |
" (1536, 0.99),\n", | |
" (1024, 0.99),\n", | |
" (512, 0.97),\n", | |
" (384, 0.96),\n", | |
" (256, 0.95),\n", | |
" (128, 0.95),\n", | |
" (64, 0.92),\n", | |
" ]\n", | |
"level_indices = list(range(len(levels)))\n", | |
"if order_of_levels == 'low_to_high':\n", | |
" level_indices.reverse()\n", | |
"elif order_of_levels.startswith('interleaved'):\n", | |
" level_indices = [i for pair in zip(level_indices, reversed(level_indices)) for i in pair][:len(levels)]\n", | |
" if order_of_levels == 'interleaved_middle_to_extreme':\n", | |
" level_indices.reverse()\n", | |
"\n", | |
"save_dir = None\n", | |
"save_dirs = []\n", | |
"os.makedirs('/content/output', exist_ok=True)\n", | |
"notebook = _message.blocking_request('get_ipynb', timeout_sec=60)\n", | |
"all_hires = []\n", | |
"for j, (image, orig, filename) in enumerate(zip(images, orig_images, filenames), start=1):\n", | |
" print('%d/%d: %s'%(j,len(filenames),filename))\n", | |
" for borders, flip in borders_flips:\n", | |
" if flip:\n", | |
" image = image.transpose(Image.FLIP_TOP_BOTTOM)\n", | |
" orig = orig.transpose(Image.FLIP_TOP_BOTTOM)\n", | |
" for text_to_use in texts:\n", | |
" for crop_first in crop_orders:\n", | |
" out_images = []\n", | |
" scores = []\n", | |
" image_prompt = ImagePrompts(image, {k: v or 0 for k, v in borders.items()}, dwt_vae if dwt_decoder_upscale else vae, device='cuda', crop_first=crop_first)\n", | |
" for i in level_indices[:num_levels]:\n", | |
" top_k, top_p = levels[i]\n", | |
" _pil_images, _scores = generate_images(\n", | |
" text_to_use,\n", | |
" tokenizer,\n", | |
" dalle,\n", | |
" dwt_vae if dwt_decoder_upscale else vae,\n", | |
" top_k=top_k,\n", | |
" top_p=top_p,\n", | |
" images_num=retries_per_level,\n", | |
" image_prompts=image_prompt,\n", | |
" temperature=temperature,\n", | |
" seed=random_seed\n", | |
" )\n", | |
" out_images += _pil_images\n", | |
" scores += _scores\n", | |
" aspect = (1, orig.size[1]/orig.size[0]) if orig.size[1]>orig.size[0] else (orig.size[0]/orig.size[1], 1) if orig.size[0]>orig.size[1] else (1,1)\n", | |
" if fix_aspect_ratio == 'before_clip' and (aspect[0]!=1 or aspect[1]!=1):\n", | |
" out_images = [im.resize((int(aspect[0]*im.size[0]), int(aspect[1]*im.size[0]))) for im in out_images]\n", | |
" rescaled_orig = None\n", | |
" if paste_original_part_over_generated.startswith('before_clip'):\n", | |
" if image.size == out_images[0].size: \n", | |
" rescaled_orig = image\n", | |
" elif orig.size == out_images[0].size:\n", | |
" rescaled_orig = orig\n", | |
" else:\n", | |
" rescaled_orig = orig.resize(out_images[0].size)\n", | |
" out_images = [crop(im, borders, mask=rescaled_orig) for im in out_images]\n", | |
" if text_to_use:\n", | |
" if flip and flip_result_back == 'before_clip':\n", | |
" out_images = [im.transpose(Image.FLIP_TOP_BOTTOM) for im in out_images]\n", | |
" out_images, _ = cherry_pick_by_ruclip(out_images, text_to_use, clip_predictor, count=None)\n", | |
" if flip and flip_result_back == 'before_clip':\n", | |
" out_images = [im.transpose(Image.FLIP_TOP_BOTTOM) for im in out_images]\n", | |
" else:\n", | |
" out_images, _ = zip(*sorted(zip(out_images, scores), key=lambda x: x[1]))\n", | |
" out_images = list(out_images)\n", | |
" if fix_aspect_ratio == 'after_clip' and (aspect[0]!=1 or aspect[1]!=1):\n", | |
" out_images = [im.resize((int(aspect[0]*im.size[0]), int(aspect[1]*im.size[0]))) for im in out_images]\n", | |
" if paste_original_part_over_generated == 'between_clip_and_super_resolution':\n", | |
" if rescaled_orig is None or rescaled_orig.size != out_images[0].size: \n", | |
" if image.size == out_images[0].size: \n", | |
" rescaled_orig = image\n", | |
" elif orig.size == out_images[0].size:\n", | |
" rescaled_orig = orig\n", | |
" else:\n", | |
" rescaled_orig = orig.resize(out_images[0].size) \n", | |
" out_images = [crop(im, borders, mask=rescaled_orig) for im in out_images]\n", | |
" if super_resolution_factor > 1:\n", | |
" out_images = super_resolution(out_images, realesrgan[super_resolution_factor])\n", | |
" if rescaled_orig is None or rescaled_orig.size != out_images[0].size: \n", | |
" if image.size == out_images[0].size: \n", | |
" rescaled_orig = image\n", | |
" elif orig.size == out_images[0].size:\n", | |
" rescaled_orig = orig\n", | |
" else:\n", | |
" rescaled_orig = orig.resize(out_images[0].size)\n", | |
" if paste_original_part_over_generated.endswith('after_super_resolution'):\n", | |
" out_images = [crop(im, borders, mask=rescaled_orig) for im in out_images] \n", | |
" out_images = [rescaled_orig, crop(rescaled_orig, borders)] + out_images\n", | |
" if flip and flip_result_back != 'no':\n", | |
" out_images = [im.transpose(Image.FLIP_TOP_BOTTOM) for im in out_images] \n", | |
" all_hires.append(out_images)\n", | |
" folders = [int(folder) for folder in os.listdir('/content/output') if os.path.isdir(os.path.join('/content/output', folder)) and folder.isnumeric()]\n", | |
" save_dir = os.path.join('/content/output', '%04d'%(max(folders, default=0)+1)) \n", | |
" save_dirs.append(save_dir)\n", | |
" x, y = out_images[0].size\n", | |
" out_images = out_images[:limit_display_results+2]\n", | |
" show(out_images, nrow, save_dir=save_dir, size=display_width * max(1, (len(out_images)//nrow) / nrow * y / x))\n", | |
" if notebook:\n", | |
" with open(os.path.join(save_dir, 'notebook.ipynb'), 'w', encoding='utf8') as f:\n", | |
" json.dump(notebook, f)\n", | |
"if save_dirs:\n", | |
" print('files saved to output folders:')\n", | |
" for save_dir in save_dirs:\n", | |
" print(save_dir)" | |
],
"id": "cebf4449"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "yULZADI6pGqH"
},
"outputs": [],
"source": [
"#@title Download high-resolution images and notebook copy\n",
"if save_dirs:\n", | |
" save_file = os.path.join('/content', save_dirs[-1].rsplit('/',1)[-1] + '.zip') \n", | |
" if len(save_dirs)==1:\n", | |
" save_dirs_str = save_dirs[-1]\n", | |
" !zip -rjqFS $save_file $save_dirs_str\n", | |
" else:\n", | |
" save_dirs_str = ' '.join(save_dir.rsplit('/',1)[-1] for save_dir in save_dirs)\n", | |
" %pushd /content/output\n", | |
" !zip -rqFS $save_file $save_dirs_str\n", | |
" %popd\n", | |
" files.download(save_file)" | |
],
"id": "yULZADI6pGqH"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "1GJ90t3GSQ4T"
},
"outputs": [],
"source": [
"#@title Display all images together\n",
"limit_display_results = 0 #@param {type:'integer'}\n",
"display_width = 30 #@param {type:'number'}\n",
"if limit_display_results < 1 or limit_display_results > num_levels * retries_per_level:\n", | |
" limit_display_results = num_levels * retries_per_level\n", | |
"nrow = int(np.ceil((limit_display_results+2)/np.ceil((limit_display_results+2)/6)))\n", | |
"for hires_images in all_hires:\n", | |
" x, y = hires_images[0].size\n", | |
" hires_images = hires_images[:limit_display_results+2]\n", | |
" show(hires_images, nrow, size=display_width * max(1, (len(hires_images)//nrow) / nrow * y / x))" | |
], | |
"id": "1GJ90t3GSQ4T" | |
} | |
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"machine_shape": "hm",
"name": "ruDALLE-Outpainting.ipynb",
"provenance": [],
"private_outputs": true,
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}