Created
July 25, 2020 22:13
-
-
Save brockmanmatt/7a346d641e2d2159eb3319f888193212 to your computer and use it in GitHub Desktop.
introToLogProbs.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "introToLogProbs.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyN0wJnZBPsjhHaevTzCp+o2", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/brockmanmatt/7a346d641e2d2159eb3319f888193212/introtologprobs.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "J7wnsgT2kPut", | |
"colab_type": "code", | |
"colab": { | |
"resources": { | |
"http://localhost:8080/nbextensions/google.colab/files.js": { | |
"data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZSB
QeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvbWl
zZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9OwogIH0pOwoKICAvLyBXYWl0IGZvciB0aGUgdXNlciB0byBwaWNrIHRoZSBmaWxlcy4KICBjb25zdCBmaWxlcyA9IHlpZWxkIHsKICAgIHByb21pc2U6IFByb21pc2UucmFjZShbcGlja2VkUHJvbWlzZSwgY2FuY2VsUHJvbWlzZV0pLAogICAgcmVzcG9uc2U6IHsKICAgICAgYWN0aW9uOiAnc3RhcnRpbmcnLAogICAgfQogIH07CgogIGNhbmNlbC5yZW1vdmUoKTsKCiAgLy8gRGlzYWJsZSB0aGUgaW5wdXQgZWxlbWVudCBzaW5jZSBmdXJ0aGVyIHBpY2tzIGFyZSBub3QgYWxsb3dlZC4KICBpbnB1dEVsZW1lbnQuZGlzYWJsZWQgPSB0cnVlOwoKICBpZiAoIWZpbGVzKSB7CiAgICByZXR1cm4gewogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgICAgfQogICAgfTsKICB9CgogIGZvciAoY29uc3QgZmlsZSBvZiBmaWxlcykgewogICAgY29uc3QgbGkgPSBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdsaScpOwogICAgbGkuYXBwZW5kKHNwYW4oZmlsZS5uYW1lLCB7Zm9udFdlaWdodDogJ2JvbGQnfSkpOwogICAgbGkuYXBwZW5kKHNwYW4oCiAgICAgICAgYCgke2ZpbGUudHlwZSB8fCAnbi9hJ30pIC0gJHtmaWxlLnNpemV9IGJ5dGVzLCBgICsKICAgICAgICBgbGFzdCBtb2RpZmllZDogJHsKICAgICAgICAgICAgZmlsZS5sYXN0TW9kaWZpZWREYXRlID8gZmlsZS5sYXN0TW9kaWZpZWREYXRlLnRvTG9jYWxlRGF0ZVN0cmluZygpIDoKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJ24vYSd9IC0gYCkpOwogICAgY29uc3QgcGVyY2VudCA9IHNwYW4oJzAlIGRvbmUnKTsKICAgIGxpLmFwcGVuZENoaWxkKHBlcmNlbnQpOwoKICAgIG91dHB1dEVsZW1lbnQuYXBwZW5kQ2hpbGQobGkpOwoKICAgIGNvbnN0IGZpbGVEYXRhUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICAgIGNvbnN0IHJlYWRlciA9IG5ldyBGaWxlUmVhZGVyKCk7CiAgICAgIHJlYWRlci5vbmxvYWQgPSAoZSkgPT4gewogICAgICAgIHJlc29sdmUoZS50YXJnZXQucmVzdWx0KTsKICAgICAgfTsKICAgICAgcmVhZGVyLnJlYWRBc0FycmF5QnVmZmVyKGZpbGUpOwogICAgfSk7CiAgICAvLyBXYWl0IGZvciB0aGUgZGF0YSB0byBiZSByZWFkeS4KICAgIGxldCBmaWxlRGF0YSA9IHlpZWxkIHsKICAgICAgcHJvbWlzZTogZmlsZURhdGFQcm9taXNlLAogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbnRpbnVlJywKICAgICAgfQogICAgfTsKCiAgICAvLyBVc2UgYSBjaHVua2VkIHNlbmRpbmcgdG8gYXZvaWQgbWVzc2FnZSBzaXplIGxpbWl0cy4gU2VlIGIvNjIxMTU2NjAuCiAgICBsZXQgcG9zaXRpb24gPSAwOwogICAgd2hpbGUgKHBvc2l0aW9uIDwgZmlsZURhdGEuYnl0ZUxlbmd0aCkgewogICAgICBjb25zdCBsZW5ndGggPSBNYXRoLm1pbihmaWxlRGF0YS5
ieXRlTGVuZ3RoIC0gcG9zaXRpb24sIE1BWF9QQVlMT0FEX1NJWkUpOwogICAgICBjb25zdCBjaHVuayA9IG5ldyBVaW50OEFycmF5KGZpbGVEYXRhLCBwb3NpdGlvbiwgbGVuZ3RoKTsKICAgICAgcG9zaXRpb24gKz0gbGVuZ3RoOwoKICAgICAgY29uc3QgYmFzZTY0ID0gYnRvYShTdHJpbmcuZnJvbUNoYXJDb2RlLmFwcGx5KG51bGwsIGNodW5rKSk7CiAgICAgIHlpZWxkIHsKICAgICAgICByZXNwb25zZTogewogICAgICAgICAgYWN0aW9uOiAnYXBwZW5kJywKICAgICAgICAgIGZpbGU6IGZpbGUubmFtZSwKICAgICAgICAgIGRhdGE6IGJhc2U2NCwKICAgICAgICB9LAogICAgICB9OwogICAgICBwZXJjZW50LnRleHRDb250ZW50ID0KICAgICAgICAgIGAke01hdGgucm91bmQoKHBvc2l0aW9uIC8gZmlsZURhdGEuYnl0ZUxlbmd0aCkgKiAxMDApfSUgZG9uZWA7CiAgICB9CiAgfQoKICAvLyBBbGwgZG9uZS4KICB5aWVsZCB7CiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICB9CiAgfTsKfQoKc2NvcGUuZ29vZ2xlID0gc2NvcGUuZ29vZ2xlIHx8IHt9OwpzY29wZS5nb29nbGUuY29sYWIgPSBzY29wZS5nb29nbGUuY29sYWIgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYi5fZmlsZXMgPSB7CiAgX3VwbG9hZEZpbGVzLAogIF91cGxvYWRGaWxlc0NvbnRpbnVlLAp9Owp9KShzZWxmKTsK", | |
"ok": true, | |
"headers": [ | |
[ | |
"content-type", | |
"application/javascript" | |
] | |
], | |
"status": 200, | |
"status_text": "" | |
} | |
}, | |
"base_uri": "https://localhost:8080/", | |
"height": 89 | |
}, | |
"outputId": "98be6105-c5b7-4e07-ae4f-e02b1cda47d0" | |
}, | |
"source": [ | |
"from google.colab import files\n", | |
"uploaded = files.upload()\n", | |
"print(\"done\")" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"\n", | |
" <input type=\"file\" id=\"files-014046c1-ce74-4a65-bc7c-70c1a9929292\" name=\"files[]\" multiple disabled\n", | |
" style=\"border:none\" />\n", | |
" <output id=\"result-014046c1-ce74-4a65-bc7c-70c1a9929292\">\n", | |
" Upload widget is only available when the cell has been executed in the\n", | |
" current browser session. Please rerun this cell to enable.\n", | |
" </output>\n", | |
" <script src=\"/nbextensions/google.colab/files.js\"></script> " | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Saving key.json to key.json\n", | |
"done\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "WHPHrUnhpKnI", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"I'll install the OpenAI API client" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "zq0ltp2xn4yt", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 139 | |
}, | |
"outputId": "da8426d1-ec6f-4e4c-b02f-57016b1254d2" | |
}, | |
"source": [ | |
"!pip install openai\n", | |
"import openai, json, pandas as pd, numpy as np" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: openai in /usr/local/lib/python3.6/dist-packages (0.2.4)\n", | |
"Requirement already satisfied: requests>=2.20; python_version >= \"3.0\" in /usr/local/lib/python3.6/dist-packages (from openai) (2.23.0)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20; python_version >= \"3.0\"->openai) (2020.6.20)\n", | |
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20; python_version >= \"3.0\"->openai) (3.0.4)\n", | |
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20; python_version >= \"3.0\"->openai) (1.24.3)\n", | |
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20; python_version >= \"3.0\"->openai) (2.10)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Q2yE0jcnpMEV", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Loading in the key.json that I uploaded; I do this so I don't need to worry about accidentally leaking creds if I share the colab (the shared notebook is, I'm 99% sure, just a JSON file that won't expose them)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "bwNXXwHen5x9", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"openai.api_key = json.load(open(\"key.json\", \"r\"))[\"key\"]" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "k67w5H0fpTkT", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Default keyword arguments to pass to the API" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "e1EwpqqJkTYh", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# arguments to send to the API\n", | |
"kwargs = {\n", | |
"\"engine\":\"davinci\",\n", | |
"\"temperature\":0,\n", | |
"\"max_tokens\":10,\n", | |
"\"stop\":\"\\n\",\n", | |
"}" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "zZubgPoOpWDH", | |
"colab_type": "text" | |
}, | |
"source": [ | |
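"A simple question/answer prompt to send to the API. (A quick wrapper to automatically save prompts and responses for later analysis was intended here, but none is defined in the cells that follow; the requests are sent directly with openai.Completion.create.)" | |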
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "kY9t_siFKaPc", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"prompt = \"\"\"q: what is the capital of France\n", | |
"a:\"\"\"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WkLAGMCSKqqn", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"r = openai.Completion.create(prompt=prompt, **kwargs)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "XoJucYblKvX4", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "4a796fde-4bc0-4b15-e1cd-15633619f157" | |
}, | |
"source": [ | |
"r[\"choices\"][0][\"text\"]" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"application/vnd.google.colaboratory.intrinsic+json": { | |
"type": "string" | |
}, | |
"text/plain": [ | |
"' Paris'" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 21 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "9Po7MjUJKzGJ", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"kwargs[\"logprobs\"] = 5" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "0teEX7n6K6qP", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"r = openai.Completion.create(prompt=prompt, **kwargs)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "aH9gsglBMfAi", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"So here are all the logprobs for the subsequent tokens; it hit the stop (\\n) and generated a few more follow-up tokens, but the returned text still stopped at the stop sequence." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WN_UTtS-K7xg", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 297 | |
}, | |
"outputId": "025c7d14-4287-49b4-8d16-9921742e5d66" | |
}, | |
"source": [ | |
"pd.DataFrame(r[\"choices\"][0][\"logprobs\"])" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>tokens</th>\n", | |
" <th>token_logprobs</th>\n", | |
" <th>top_logprobs</th>\n", | |
" <th>text_offset</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>Paris</td>\n", | |
" <td>-0.828964</td>\n", | |
" <td>{' par': -1.6102142, ' Par': -4.235214, ' PAR'...</td>\n", | |
" <td>35</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>\\n</td>\n", | |
" <td>-0.364414</td>\n", | |
" <td>{',': -3.1456642, '.': -2.6144142, '\n", | |
"': -0.364...</td>\n", | |
" <td>41</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>q</td>\n", | |
" <td>-1.213570</td>\n", | |
" <td>{'\n", | |
"': -1.5885696, 'The': -4.2291946, 'b': -2.4...</td>\n", | |
" <td>41</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>:</td>\n", | |
" <td>-0.004189</td>\n", | |
" <td>{' :': -7.0354385, '.': -7.0354385, '1': -8.53...</td>\n", | |
" <td>41</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>what</td>\n", | |
" <td>-0.479179</td>\n", | |
" <td>{' What': -2.2916794, ' who': -3.4791794, ' wh...</td>\n", | |
" <td>41</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>is</td>\n", | |
" <td>-0.297340</td>\n", | |
" <td>{' country': -4.4223404, ' color': -4.0473404,...</td>\n", | |
" <td>41</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>the</td>\n", | |
" <td>-0.146500</td>\n", | |
" <td>{' a': -4.0527496, ' the': -0.14649963, ' 1': ...</td>\n", | |
" <td>41</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>capital</td>\n", | |
" <td>-0.774006</td>\n", | |
" <td>{' name': -3.586506, ' color': -3.867756, ' ca...</td>\n", | |
" <td>41</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" tokens ... text_offset\n", | |
"0 Paris ... 35\n", | |
"1 \\n ... 41\n", | |
"2 q ... 41\n", | |
"3 : ... 41\n", | |
"4 what ... 41\n", | |
"5 is ... 41\n", | |
"6 the ... 41\n", | |
"7 capital ... 41\n", | |
"\n", | |
"[8 rows x 4 columns]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 31 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "le1fHX-BMtGj", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"We can look more closely at the possibilities it considered for the first token (Paris), converting the logprobs to % by taking 100 * e**logprob\n", | |
"\n", | |
"Paris wins with about 44%, although par was the runner-up at about 20%" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "eGfnBWvHK8hS", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 204 | |
}, | |
"outputId": "cbf7be34-3ca6-46f2-d523-1237edbe4012" | |
}, | |
"source": [ | |
"scores = pd.DataFrame([r[\"choices\"][0][\"logprobs\"][\"top_logprobs\"][0]]).T\n", | |
"scores.columns = [\"logprob\"]\n", | |
"scores[\"%\"] = scores[\"logprob\"].apply(lambda x: 100*np.e**x)\n", | |
"scores" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>logprob</th>\n", | |
" <th>%</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>par</th>\n", | |
" <td>-1.610214</td>\n", | |
" <td>19.984480</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>Par</th>\n", | |
" <td>-4.235214</td>\n", | |
" <td>1.447671</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>PAR</th>\n", | |
" <td>-4.172714</td>\n", | |
" <td>1.541038</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>Paris</th>\n", | |
" <td>-0.828964</td>\n", | |
" <td>43.650117</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>what</th>\n", | |
" <td>-4.422714</td>\n", | |
" <td>1.200162</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" logprob %\n", | |
" par -1.610214 19.984480\n", | |
" Par -4.235214 1.447671\n", | |
" PAR -4.172714 1.541038\n", | |
" Paris -0.828964 43.650117\n", | |
" what -4.422714 1.200162" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 45 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "UJKQ0T9gP9Ce", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"We can see that if we increase the temperature, it picks non-optimal (lower-probability) tokens. However, it still tries to complete the task and eventually makes it back to Paris (although that's not guaranteed)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "6a4yfgKmP5H9", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"kwargs[\"temperature\"] = 1.2\n", | |
"r = openai.Completion.create(prompt=prompt, **kwargs)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "5f0sTMWCP5NY", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 359 | |
}, | |
"outputId": "64cd64ce-80ae-4aad-9138-c34fbf6553a8" | |
}, | |
"source": [ | |
"pd.DataFrame(r[\"choices\"][0][\"logprobs\"])" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>tokens</th>\n", | |
" <th>token_logprobs</th>\n", | |
" <th>top_logprobs</th>\n", | |
" <th>text_offset</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>that</td>\n", | |
" <td>-5.997170</td>\n", | |
" <td>{' Paris': -0.8409195, ' par': -1.5284195, ' P...</td>\n", | |
" <td>35</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>'s</td>\n", | |
" <td>-0.899242</td>\n", | |
" <td>{' is': -1.1492424, ''s': -0.8992424, 'bytes:\\...</td>\n", | |
" <td>40</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>an</td>\n", | |
" <td>-3.084446</td>\n", | |
" <td>{' easy': -3.006321, ' a': -1.475071, ' an': -...</td>\n", | |
" <td>42</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>easy</td>\n", | |
" <td>-1.239227</td>\n", | |
" <td>{' example': -3.5361023, ' easy': -1.2392273, ...</td>\n", | |
" <td>45</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>one</td>\n", | |
" <td>-0.112442</td>\n", | |
" <td>{' q': -6.128067, ' question': -2.424942, ' an...</td>\n", | |
" <td>50</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>-</td>\n", | |
" <td>-4.639725</td>\n", | |
" <td>{',': -1.1084747, '.': -2.1397247, ':': -2.483...</td>\n", | |
" <td>54</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>Paris</td>\n", | |
" <td>-2.285274</td>\n", | |
" <td>{' par': -3.0977745, ' it': -2.3321495, 'Paris...</td>\n", | |
" <td>55</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>(</td>\n", | |
" <td>-4.684345</td>\n", | |
" <td>{'\n", | |
"': -0.52809525, '.': -2.0280952, '!': -2.24...</td>\n", | |
" <td>61</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>Y</td>\n", | |
" <td>-7.411562</td>\n", | |
" <td>{'or': -2.895937, 'the': -3.380312, 'correct':...</td>\n", | |
" <td>63</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>ay</td>\n", | |
" <td>-1.383808</td>\n", | |
" <td>{'ahoo': -3.2275581, 'ay': -1.3838081, 'AY': -...</td>\n", | |
" <td>64</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" tokens ... text_offset\n", | |
"0 that ... 35\n", | |
"1 's ... 40\n", | |
"2 an ... 42\n", | |
"3 easy ... 45\n", | |
"4 one ... 50\n", | |
"5 - ... 54\n", | |
"6 Paris ... 55\n", | |
"7 ( ... 61\n", | |
"8 Y ... 63\n", | |
"9 ay ... 64\n", | |
"\n", | |
"[10 rows x 4 columns]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 87 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Z5CtVmbyP5Kl", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "HMsKJCv1P5FS", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "kU5aacCKM1lS", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"prompt = \"\"\"These word rhyme:\n", | |
"red:led\n", | |
"dog:frog\n", | |
"small:tall\n", | |
"train:\"\"\"\n", | |
"kwargs[\"logprobs\"] = 10\n", | |
"kwargs[\"max_tokens\"] = 20\n", | |
"kwargs[\"temperature\"] = 0\n", | |
"\n", | |
"r = openai.Completion.create(prompt=prompt, **kwargs)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "kuewt1akNPRM", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 359 | |
}, | |
"outputId": "54ad0bfa-f8aa-4e9c-8b15-750e88998796" | |
}, | |
"source": [ | |
"scores = pd.DataFrame([r[\"choices\"][0][\"logprobs\"][\"top_logprobs\"][0]]).T\n", | |
"scores.columns = [\"logprob\"]\n", | |
"scores[\"%\"] = scores[\"logprob\"].apply(lambda x: 100*np.e**x)\n", | |
"scores.sort_values(by=\"%\", ascending=False)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>logprob</th>\n", | |
" <th>%</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>pain</th>\n", | |
" <td>-1.277435</td>\n", | |
" <td>27.875130</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>rain</th>\n", | |
" <td>-2.277435</td>\n", | |
" <td>10.254687</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>brain</th>\n", | |
" <td>-2.621185</td>\n", | |
" <td>7.271662</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>chain</th>\n", | |
" <td>-3.277435</td>\n", | |
" <td>3.772489</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>str</th>\n", | |
" <td>-3.355560</td>\n", | |
" <td>3.488982</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>plane</th>\n", | |
" <td>-3.621185</td>\n", | |
" <td>2.675095</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>gain</th>\n", | |
" <td>-3.746185</td>\n", | |
" <td>2.360763</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>main</th>\n", | |
" <td>-3.933685</td>\n", | |
" <td>1.957141</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>p</th>\n", | |
" <td>-4.027435</td>\n", | |
" <td>1.781997</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>plant</th>\n", | |
" <td>-4.089935</td>\n", | |
" <td>1.674032</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" logprob %\n", | |
"pain -1.277435 27.875130\n", | |
"rain -2.277435 10.254687\n", | |
"brain -2.621185 7.271662\n", | |
"chain -3.277435 3.772489\n", | |
"str -3.355560 3.488982\n", | |
"plane -3.621185 2.675095\n", | |
"gain -3.746185 2.360763\n", | |
"main -3.933685 1.957141\n", | |
"p -4.027435 1.781997\n", | |
"plant -4.089935 1.674032" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 116 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-EoMNmusN1cw", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"prompt = \"\"\"These pairs of sentences rhyme:\n", | |
"My favorite color is red\n", | |
"ends with: \"red\"\n", | |
"\"red\" rhymes with \"bed\"\n", | |
"Rhyme: It's the color of my bed\n", | |
"-----\n", | |
"I once had a dog\n", | |
"ends with: \"dog\"\n", | |
"\"dog\" rhymes with \"frog\"\n", | |
"Rhyme: That good boy ate a frog\n", | |
"-----\n", | |
"I wish I was small\n", | |
"ends with: \"small\"\n", | |
"\"small\" rhymes with \"tall\"\n", | |
"Rhyme: Instead I'm so tall ='(\n", | |
"-----\n", | |
"That's a cool train\n", | |
"ends with:\"\"\"\n", | |
"kwargs[\"logprobs\"] = 5\n", | |
"kwargs[\"max_tokens\"] = 40\n", | |
"kwargs[\"temperature\"] = 0\n", | |
"kwargs[\"stop\"] = \"-----\"\n", | |
"\n", | |
"r = openai.Completion.create(prompt=prompt, **kwargs)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "LIovt7zaS_eP", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "63f13251-d56d-4001-d54b-9cd6fb44dbd9" | |
}, | |
"source": [ | |
"r[\"choices\"][0][\"text\"]" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"application/vnd.google.colaboratory.intrinsic+json": { | |
"type": "string" | |
}, | |
"text/plain": [ | |
"' \"train\"\\n\"train\" rhymes with \"rain\"\\nRhyme: I like to ride the rain\\n'" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 170 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pu_sgx2vSuWW", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 204 | |
}, | |
"outputId": "8348e9fc-1efe-43cc-cc96-d27a4f70a5e1" | |
}, | |
"source": [ | |
"scores = pd.DataFrame([r[\"choices\"][0][\"logprobs\"][\"top_logprobs\"][11]]).T\n", | |
"scores.columns = [\"logprob\"]\n", | |
"scores[\"%\"] = scores[\"logprob\"].apply(lambda x: 100*np.e**x)\n", | |
"scores.sort_values(by=\"%\", ascending=False)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>logprob</th>\n", | |
" <th>%</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>rain</th>\n", | |
" <td>-1.411346</td>\n", | |
" <td>24.381479</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>pain</th>\n", | |
" <td>-2.051971</td>\n", | |
" <td>12.848137</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>brain</th>\n", | |
" <td>-2.567596</td>\n", | |
" <td>7.671973</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>plane</th>\n", | |
" <td>-3.458221</td>\n", | |
" <td>3.148571</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>str</th>\n", | |
" <td>-3.583221</td>\n", | |
" <td>2.778604</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" logprob %\n", | |
"rain -1.411346 24.381479\n", | |
"pain -2.051971 12.848137\n", | |
"brain -2.567596 7.671973\n", | |
"plane -3.458221 3.148571\n", | |
"str -3.583221 2.778604" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 171 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "URQ2ENPwSveY", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 483 | |
}, | |
"outputId": "3673773e-8673-4f79-ee59-68ae8b37681b" | |
}, | |
"source": [ | |
"pd.DataFrame(r[\"choices\"][0][\"logprobs\"])[18:]" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>tokens</th>\n", | |
" <th>token_logprobs</th>\n", | |
" <th>top_logprobs</th>\n", | |
" <th>text_offset</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>18</th>\n", | |
" <td>I</td>\n", | |
" <td>-1.807033</td>\n", | |
" <td>{' That': -1.9320335, ' I': -1.8070335, ' The'...</td>\n", | |
" <td>407</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>19</th>\n", | |
" <td>like</td>\n", | |
" <td>-1.590843</td>\n", | |
" <td>{''m': -2.6533432, ' love': -2.5908432, ' wish...</td>\n", | |
" <td>409</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>20</th>\n", | |
" <td>to</td>\n", | |
" <td>-1.032188</td>\n", | |
" <td>{' trains': -2.5634384, ' the': -2.0946884, ' ...</td>\n", | |
" <td>414</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>21</th>\n", | |
" <td>ride</td>\n", | |
" <td>-1.397640</td>\n", | |
" <td>{' watch': -1.5538902, ' hear': -3.6476402, ' ...</td>\n", | |
" <td>417</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>22</th>\n", | |
" <td>the</td>\n", | |
" <td>-1.274738</td>\n", | |
" <td>{' a': -2.3059883, ' the': -1.2747383, ' in': ...</td>\n", | |
" <td>422</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>23</th>\n", | |
" <td>rain</td>\n", | |
" <td>-0.667145</td>\n", | |
" <td>{' Rain': -5.68277, ' subway': -4.823395, ' \"'...</td>\n", | |
" <td>426</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>24</th>\n", | |
" <td>\\n</td>\n", | |
" <td>-0.295723</td>\n", | |
" <td>{'.': -3.389473, '\n", | |
"': -0.29572296, ' train': -...</td>\n", | |
" <td>431</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>25</th>\n", | |
" <td>-----</td>\n", | |
" <td>-0.310501</td>\n", | |
" <td>{'\n", | |
"': -2.529251, '-----': -0.3105011, 'R': -4....</td>\n", | |
" <td>432</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>26</th>\n", | |
" <td>\\n</td>\n", | |
" <td>-0.019367</td>\n", | |
" <td>{'\n", | |
"': -0.019367218, ' ': -6.050617, ' I': -8.0...</td>\n", | |
" <td>432</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>27</th>\n", | |
" <td>I</td>\n", | |
" <td>-1.208935</td>\n", | |
" <td>{'\n", | |
"': -2.8651848, 'That': -3.0839348, 'My': -2...</td>\n", | |
" <td>432</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>28</th>\n", | |
" <td>like</td>\n", | |
" <td>-1.733200</td>\n", | |
" <td>{''m': -2.51445, ' love': -2.70195, ' have': -...</td>\n", | |
" <td>432</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>29</th>\n", | |
" <td>to</td>\n", | |
" <td>-0.730225</td>\n", | |
" <td>{' the': -3.2614746, ' to': -0.7302246, ' that...</td>\n", | |
" <td>432</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>30</th>\n", | |
" <td>eat</td>\n", | |
" <td>-1.868992</td>\n", | |
" <td>{' read': -2.9002419, ' sing': -3.4002419, ' e...</td>\n", | |
" <td>432</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>31</th>\n", | |
" <td>\\n</td>\n", | |
" <td>-2.445057</td>\n", | |
" <td>{' pizza': -3.257557, '\n", | |
"': -2.445057, ' pie': ...</td>\n", | |
" <td>432</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" tokens ... text_offset\n", | |
"18 I ... 407\n", | |
"19 like ... 409\n", | |
"20 to ... 414\n", | |
"21 ride ... 417\n", | |
"22 the ... 422\n", | |
"23 rain ... 426\n", | |
"24 \\n ... 431\n", | |
"25 ----- ... 432\n", | |
"26 \\n ... 432\n", | |
"27 I ... 432\n", | |
"28 like ... 432\n", | |
"29 to ... 432\n", | |
"30 eat ... 432\n", | |
"31 \\n ... 432\n", | |
"\n", | |
"[14 rows x 4 columns]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 172 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "BLRp1vNBTYI-", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 80 | |
}, | |
"outputId": "af7037bb-a3fc-49d0-fc53-93c7494b4a51" | |
}, | |
"source": [ | |
"pd.DataFrame([r[\"choices\"][0][\"logprobs\"][\"top_logprobs\"][23]]).apply(lambda x: 100*np.e**x)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Rain</th>\n", | |
" <th>subway</th>\n", | |
" <th>\"</th>\n", | |
" <th>train</th>\n", | |
" <th>rain</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0.340412</td>\n", | |
" <td>0.803945</td>\n", | |
" <td>0.458074</td>\n", | |
" <td>42.543428</td>\n", | |
" <td>51.31717</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" Rain subway \" train rain\n", | |
"0 0.340412 0.803945 0.458074 42.543428 51.31717" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 173 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "TZGcB9AXYrgs", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"rhymed = pd.DataFrame(r[\"choices\"][0][\"logprobs\"])[18:]" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "NdNI_F9wVEUG", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"What if we make this more creative (sampling at temperature 0.5)?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "wRxYDGQdUuMS", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"prompt = \"\"\"These pairs of sentences rhyme:\n", | |
"My favorite color is red\n", | |
"ends with: \"red\"\n", | |
"\"red\" rhymes with \"bed\"\n", | |
"Rhyme: It's the color of my bed\n", | |
"-----\n", | |
"I once had a dog\n", | |
"ends with: \"dog\"\n", | |
"\"dog\" rhymes with \"frog\"\n", | |
"Rhyme: That good boy ate a frog\n", | |
"-----\n", | |
"I wish I was small\n", | |
"ends with: \"small\"\n", | |
"\"small\" rhymes with \"tall\"\n", | |
"Rhyme: Instead I'm so tall ='(\n", | |
"-----\n", | |
"That's a cool train\n", | |
"ends with:\"\"\"\n", | |
"kwargs[\"logprobs\"] = 5\n", | |
"kwargs[\"max_tokens\"] = 40\n", | |
"kwargs[\"temperature\"] = .5\n", | |
"kwargs[\"stop\"] = \"-----\"\n", | |
"\n", | |
"r = openai.Completion.create(prompt=prompt, **kwargs)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "hStXV_UCVGjO", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "715b868a-132a-4959-8498-4b64dcc0ff50" | |
}, | |
"source": [ | |
"r[\"choices\"][0][\"text\"]" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"application/vnd.google.colaboratory.intrinsic+json": { | |
"type": "string" | |
}, | |
"text/plain": [ | |
"' \"train\"\\n\"train\" rhymes with \"plane\"\\nRhyme: It\\'s not a plane\\n'" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 192 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "08OYB7FjVYw2", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"df = pd.DataFrame(r[\"choices\"][0][\"logprobs\"])[18:]" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "5udyaq7NVoCv", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def getTopValueFromDict(someDict):\n", | |
"    \"\"\"Return the best-scoring entry of a token -> logprob mapping.\n", | |
"\n", | |
"    The result is a (logprob, token) tuple. Tuples compare element-wise,\n", | |
"    so the highest logprob wins and ties fall back to comparing the\n", | |
"    token strings.\n", | |
"    \"\"\"\n", | |
"    pairs = ((logprob, token) for token, logprob in dict(someDict).items())\n", | |
"    return max(pairs)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Mn0atxbbVa34", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"df[\"actual_top_logprob\"] = df.top_logprobs.apply(lambda x: getTopValueFromDict(x))" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "XnKf1xSoVIPq", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 483 | |
}, | |
"outputId": "679f20bd-1467-4469-c366-1e553f53d3b4" | |
}, | |
"source": [ | |
"df" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>tokens</th>\n", | |
" <th>token_logprobs</th>\n", | |
" <th>top_logprobs</th>\n", | |
" <th>text_offset</th>\n", | |
" <th>actual_top_logprob</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>18</th>\n", | |
" <td>It</td>\n", | |
" <td>-2.047520</td>\n", | |
" <td>{' That': -1.9850197, ' I': -1.6725197, ' The'...</td>\n", | |
" <td>408</td>\n", | |
" <td>(-1.6725197, I)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>19</th>\n", | |
" <td>'s</td>\n", | |
" <td>-0.899273</td>\n", | |
" <td>{' flies': -2.649273, ' makes': -3.586773, ' g...</td>\n", | |
" <td>411</td>\n", | |
" <td>(-0.8992729, 's)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>20</th>\n", | |
" <td>not</td>\n", | |
" <td>-2.945133</td>\n", | |
" <td>{' a': -1.6326332, ' the': -2.4451332, ' not':...</td>\n", | |
" <td>413</td>\n", | |
" <td>(-1.6326332, a)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>21</th>\n", | |
" <td>a</td>\n", | |
" <td>-1.259831</td>\n", | |
" <td>{' a': -1.2598305, ' as': -2.4160805, ' like':...</td>\n", | |
" <td>417</td>\n", | |
" <td>(-1.2598305, a)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>22</th>\n", | |
" <td>plane</td>\n", | |
" <td>-0.845264</td>\n", | |
" <td>{' big': -3.0952644, ' cool': -3.1577644, ' pl...</td>\n", | |
" <td>419</td>\n", | |
" <td>(-0.84526443, plane)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>23</th>\n", | |
" <td>\\n</td>\n", | |
" <td>-1.518738</td>\n", | |
" <td>{',': -1.3312378, '\n", | |
"': -1.5187378, ' or': -1.7...</td>\n", | |
" <td>425</td>\n", | |
" <td>(-1.3312378, ,)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>24</th>\n", | |
" <td>-----</td>\n", | |
" <td>-0.515564</td>\n", | |
" <td>{'\"': -3.343689, 'It': -2.859314, '\n", | |
"': -3.4843...</td>\n", | |
" <td>426</td>\n", | |
" <td>(-0.51556396, -----)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>25</th>\n", | |
" <td>\\n</td>\n", | |
" <td>-0.013977</td>\n", | |
" <td>{'\n", | |
"': -0.013977051, ' ': -6.263977, ' I': -7.9...</td>\n", | |
" <td>426</td>\n", | |
" <td>(-0.013977051, \\n)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>26</th>\n", | |
" <td>That</td>\n", | |
" <td>-2.864292</td>\n", | |
" <td>{'\n", | |
"': -3.2392921, 'That': -2.8642921, 'My': -2...</td>\n", | |
" <td>426</td>\n", | |
" <td>(-1.2080421, I)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>27</th>\n", | |
" <td>'s</td>\n", | |
" <td>-4.619221</td>\n", | |
" <td>{''s': -4.6192207, ' is': -5.1348457, ' was': ...</td>\n", | |
" <td>426</td>\n", | |
" <td>(-4.6192207, 's)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>28</th>\n", | |
" <td>a</td>\n", | |
" <td>-0.000057</td>\n", | |
" <td>{' a': -5.722046e-05, ' my': -10.562557, ' not...</td>\n", | |
" <td>426</td>\n", | |
" <td>(-5.722046e-05, a)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>29</th>\n", | |
" <td>cool</td>\n", | |
" <td>-2.566330</td>\n", | |
" <td>{' cool': -2.56633, ' big': -2.738205, ' nice'...</td>\n", | |
" <td>426</td>\n", | |
" <td>(-2.56633, cool)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>30</th>\n", | |
" <td>plane</td>\n", | |
" <td>-0.089695</td>\n", | |
" <td>{' plane': -0.08969498, ' train': -2.589695, '...</td>\n", | |
" <td>426</td>\n", | |
" <td>(-0.08969498, plane)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>31</th>\n", | |
" <td>\\n</td>\n", | |
" <td>-0.940624</td>\n", | |
" <td>{'\n", | |
"': -0.94062424, '!': -1.5343742, '.': -1.53...</td>\n", | |
" <td>426</td>\n", | |
" <td>(-0.94062424, \\n)</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" tokens token_logprobs ... text_offset actual_top_logprob\n", | |
"18 It -2.047520 ... 408 (-1.6725197, I)\n", | |
"19 's -0.899273 ... 411 (-0.8992729, 's)\n", | |
"20 not -2.945133 ... 413 (-1.6326332, a)\n", | |
"21 a -1.259831 ... 417 (-1.2598305, a)\n", | |
"22 plane -0.845264 ... 419 (-0.84526443, plane)\n", | |
"23 \\n -1.518738 ... 425 (-1.3312378, ,)\n", | |
"24 ----- -0.515564 ... 426 (-0.51556396, -----)\n", | |
"25 \\n -0.013977 ... 426 (-0.013977051, \\n)\n", | |
"26 That -2.864292 ... 426 (-1.2080421, I)\n", | |
"27 's -4.619221 ... 426 (-4.6192207, 's)\n", | |
"28 a -0.000057 ... 426 (-5.722046e-05, a)\n", | |
"29 cool -2.566330 ... 426 (-2.56633, cool)\n", | |
"30 plane -0.089695 ... 426 (-0.08969498, plane)\n", | |
"31 \\n -0.940624 ... 426 (-0.94062424, \\n)\n", | |
"\n", | |
"[14 rows x 5 columns]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 196 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ZtOVl-C9ZN9m", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"rhyming_pt5 = df.copy()" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "GSFVvNSSYKZ4", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"How do the logprobs for a bad completion compare to the logprobs for a good one?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "OjRAY_8iVOCi", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "vjwcNkRsZUln", | |
"colab": {} | |
}, | |
"source": [ | |
"prompt = \"\"\"These pairs of sentences rhyme:\n", | |
"My favorite color is red\n", | |
"ends with: \"red\"\n", | |
"\"red\" rhymes with \"bed\"\n", | |
"Rhyme: It's the color of my bed\n", | |
"-----\n", | |
"I once had a dog\n", | |
"ends with: \"dog\"\n", | |
"\"dog\" rhymes with \"frog\"\n", | |
"Rhyme: That good boy ate a frog\n", | |
"-----\n", | |
"I wish I was small\n", | |
"ends with: \"small\"\n", | |
"\"small\" rhymes with \"tall\"\n", | |
"Rhyme: Instead I'm so tall ='(\n", | |
"-----\n", | |
"That's a cool train\n", | |
"ends with:\"\"\"\n", | |
"kwargs[\"logprobs\"] = 5\n", | |
"kwargs[\"max_tokens\"] = 40\n", | |
"kwargs[\"temperature\"] = .5\n", | |
"kwargs[\"stop\"] = \"-----\"\n", | |
"\n", | |
"r = openai.Completion.create(prompt=prompt, **kwargs)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "cKiTbWxRZUlr", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "697448e5-d856-4d65-84de-e9a599f7944d" | |
}, | |
"source": [ | |
"r[\"choices\"][0][\"text\"]" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"application/vnd.google.colaboratory.intrinsic+json": { | |
"type": "string" | |
}, | |
"text/plain": [ | |
"' \"train\"\\n\"train\" rhymes with \"rain\"\\nRhyme: The rain is so cool\\n'" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 199 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "SiJNYKcIZUlu", | |
"colab": {} | |
}, | |
"source": [ | |
"df = pd.DataFrame(r[\"choices\"][0][\"logprobs\"])[18:]" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "9P30yDlrZUlx", | |
"colab": {} | |
}, | |
"source": [ | |
"# NOTE: same helper as defined in an earlier cell of this notebook;\n", | |
"# re-running this cell simply rebinds the name to an equivalent function.\n", | |
"def getTopValueFromDict(someDict):\n", | |
"    \"\"\"Return the best-scoring entry of a token -> logprob mapping.\n", | |
"\n", | |
"    The result is a (logprob, token) tuple. Tuples compare element-wise,\n", | |
"    so the highest logprob wins and ties fall back to comparing the\n", | |
"    token strings.\n", | |
"    \"\"\"\n", | |
"    pairs = ((logprob, token) for token, logprob in dict(someDict).items())\n", | |
"    return max(pairs)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "mW_4_NmEZUlz", | |
"colab": {} | |
}, | |
"source": [ | |
"df[\"actual_top_logprob\"] = df.top_logprobs.apply(lambda x: getTopValueFromDict(x))" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "aVIqqoYEZUl2", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 483 | |
}, | |
"outputId": "42e04913-1c2a-4f4a-bc47-d530682fca51" | |
}, | |
"source": [ | |
"df" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>tokens</th>\n", | |
" <th>token_logprobs</th>\n", | |
" <th>top_logprobs</th>\n", | |
" <th>text_offset</th>\n", | |
" <th>actual_top_logprob</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>18</th>\n", | |
" <td>The</td>\n", | |
" <td>-2.859199</td>\n", | |
" <td>{' That': -1.9529495, ' I': -1.7966995, ' The'...</td>\n", | |
" <td>407</td>\n", | |
" <td>(-1.7966995, I)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>19</th>\n", | |
" <td>rain</td>\n", | |
" <td>-0.811218</td>\n", | |
" <td>{' water': -4.4674683, ' cool': -4.1237183, ' ...</td>\n", | |
" <td>411</td>\n", | |
" <td>(-0.81121826, rain)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>20</th>\n", | |
" <td>is</td>\n", | |
" <td>-2.080967</td>\n", | |
" <td>{' came': -1.799717, ' makes': -2.799717, ' co...</td>\n", | |
" <td>416</td>\n", | |
" <td>(-1.799717, came)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>21</th>\n", | |
" <td>so</td>\n", | |
" <td>-1.509876</td>\n", | |
" <td>{' coming': -2.7911263, ' a': -2.6348763, ' co...</td>\n", | |
" <td>419</td>\n", | |
" <td>(-1.5098763, so)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>22</th>\n", | |
" <td>cool</td>\n", | |
" <td>-1.401932</td>\n", | |
" <td>{' fun': -2.4019318, ' cool': -1.4019318, ' tr...</td>\n", | |
" <td>422</td>\n", | |
" <td>(-1.4019318, cool)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>23</th>\n", | |
" <td>\\n</td>\n", | |
" <td>-0.392258</td>\n", | |
" <td>{'!': -3.4860077, ',': -3.7047577, '.': -3.517...</td>\n", | |
" <td>427</td>\n", | |
" <td>(-0.3922577, \\n)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>24</th>\n", | |
" <td>-----</td>\n", | |
" <td>-0.415947</td>\n", | |
" <td>{'\n", | |
"': -2.228447, '-----': -0.41594696, 'R': -4...</td>\n", | |
" <td>428</td>\n", | |
" <td>(-0.41594696, -----)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>25</th>\n", | |
" <td>\\n</td>\n", | |
" <td>-0.016262</td>\n", | |
" <td>{'\n", | |
"': -0.016262054, ' ': -6.297512, ' I': -7.7...</td>\n", | |
" <td>428</td>\n", | |
" <td>(-0.016262054, \\n)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>26</th>\n", | |
" <td>I</td>\n", | |
" <td>-1.232170</td>\n", | |
" <td>{'\n", | |
"': -2.95092, 'That': -2.91967, 'My': -2.482...</td>\n", | |
" <td>428</td>\n", | |
" <td>(-1.2321701, I)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>27</th>\n", | |
" <td>like</td>\n", | |
" <td>-1.846847</td>\n", | |
" <td>{''m': -2.5030975, ' love': -2.5030975, ' have...</td>\n", | |
" <td>428</td>\n", | |
" <td>(-1.8468475, like)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>28</th>\n", | |
" <td>to</td>\n", | |
" <td>-1.022278</td>\n", | |
" <td>{' the': -2.9597778, ' to': -1.0222778, ' that...</td>\n", | |
" <td>428</td>\n", | |
" <td>(-1.0222778, to)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>29</th>\n", | |
" <td>eat</td>\n", | |
" <td>-1.644001</td>\n", | |
" <td>{' read': -2.925251, ' sing': -3.300251, ' eat...</td>\n", | |
" <td>428</td>\n", | |
" <td>(-1.644001, eat)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>30</th>\n", | |
" <td>rice</td>\n", | |
" <td>-3.705196</td>\n", | |
" <td>{'\n", | |
"': -2.8301964, ' ice': -2.8614464, ' pizza'...</td>\n", | |
" <td>428</td>\n", | |
" <td>(-2.8301964, \\n)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>31</th>\n", | |
" <td>\\n</td>\n", | |
" <td>-0.055901</td>\n", | |
" <td>{',': -5.5871506, '.': -4.6809006, '\n", | |
"': -0.055...</td>\n", | |
" <td>428</td>\n", | |
" <td>(-0.055900574, \\n)</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" tokens token_logprobs ... text_offset actual_top_logprob\n", | |
"18 The -2.859199 ... 407 (-1.7966995, I)\n", | |
"19 rain -0.811218 ... 411 (-0.81121826, rain)\n", | |
"20 is -2.080967 ... 416 (-1.799717, came)\n", | |
"21 so -1.509876 ... 419 (-1.5098763, so)\n", | |
"22 cool -1.401932 ... 422 (-1.4019318, cool)\n", | |
"23 \\n -0.392258 ... 427 (-0.3922577, \\n)\n", | |
"24 ----- -0.415947 ... 428 (-0.41594696, -----)\n", | |
"25 \\n -0.016262 ... 428 (-0.016262054, \\n)\n", | |
"26 I -1.232170 ... 428 (-1.2321701, I)\n", | |
"27 like -1.846847 ... 428 (-1.8468475, like)\n", | |
"28 to -1.022278 ... 428 (-1.0222778, to)\n", | |
"29 eat -1.644001 ... 428 (-1.644001, eat)\n", | |
"30 rice -3.705196 ... 428 (-2.8301964, \\n)\n", | |
"31 \\n -0.055901 ... 428 (-0.055900574, \\n)\n", | |
"\n", | |
"[14 rows x 5 columns]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 203 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "b2ZItSGmVRU7", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"bad_pt5 = df.copy()" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "3DUTQZDpcC1f", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"OK, real quick: how do the average logprobs compare? The completion with the highest average logprob rhymes! So this is a good indication that the completion with the highest average logprob will tend to be the correct answer." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "38Ghx9zpcadg", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "d18ef09d-a8f9-467e-f5a2-0b5cbe21a2b6" | |
}, | |
"source": [ | |
"rhymed[:rhymed.tokens.to_list().index(\"\\n\")].token_logprobs.mean()" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"-1.2949314" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 213 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "6fo2G7SeZau2", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "a2babeaf-11b4-4f56-a48a-12b7c8ce7b6b" | |
}, | |
"source": [ | |
"rhyming_pt5[:rhyming_pt5.tokens.to_list().index(\"\\n\")].token_logprobs.mean()" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"-1.5994041460000001" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 211 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "cTbGva6hcIKv", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "ed0fc4cd-c915-4d87-c78a-a7a8c7b6a8d8" | |
}, | |
"source": [ | |
"bad_pt5[:bad_pt5.tokens.to_list().index(\"\\n\")].token_logprobs.mean()" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"-1.7326385720000002" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 212 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "k5NDvpIGrQnY", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Cool! This is actually how best_of works; for instance, let's get n=10 at temp=.5" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "HazROq7GcZ0_", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"prompt = \"\"\"These pairs of sentences rhyme:\n", | |
"My favorite color is red\n", | |
"ends with: \"red\"\n", | |
"\"red\" rhymes with \"bed\"\n", | |
"Rhyme: It's the color of my bed\n", | |
"-----\n", | |
"I once had a dog\n", | |
"ends with: \"dog\"\n", | |
"\"dog\" rhymes with \"frog\"\n", | |
"Rhyme: That good boy ate a frog\n", | |
"-----\n", | |
"I wish I was small\n", | |
"ends with: \"small\"\n", | |
"\"small\" rhymes with \"tall\"\n", | |
"Rhyme: Instead I'm so tall ='(\n", | |
"-----\n", | |
"That's a cool train\n", | |
"ends with:\"\"\"\n", | |
"kwargs[\"logprobs\"] = 5\n", | |
"kwargs[\"max_tokens\"] = 40\n", | |
"kwargs[\"temperature\"] = .5\n", | |
"kwargs[\"stop\"] = \"-----\"\n", | |
"kwargs[\"n\"] = 10\n", | |
"\n", | |
"r = openai.Completion.create(prompt=prompt, **kwargs)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4_9zp9qlrZwC", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# The second-to-last line of each completion is its Rhyme line; [7:] drops\n", | |
"# the 7-character 'Rhyme: ' label. Iterating over the returned choices avoids\n", | |
"# hard-coding the value of kwargs[\"n\"].\n", | |
"texts = [choice[\"text\"].split(\"\\n\")[-2][7:] for choice in r[\"choices\"]]" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dRk6wSS3r8x1", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Mean token logprob of the rhyme line for every returned completion.\n", | |
"logprobs = []\n", | |
"for choice in r[\"choices\"]:\n", | |
"    # Skip the first 18 completion tokens (same offset as the cells above).\n", | |
"    # NOTE(review): the fixed offset assumes the text before the rhyme line\n", | |
"    # always tokenizes to 18 tokens -- confirm for multi-token rhyme words.\n", | |
"    tokens_df = pd.DataFrame(choice[\"logprobs\"])[18:]\n", | |
"    # Average only up to the first newline, i.e. the end of the rhyme line.\n", | |
"    rhyme_end = tokens_df.tokens.to_list().index(\"\\n\")\n", | |
"    logprobs.append(tokens_df[:rhyme_end].token_logprobs.mean())" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "TJviL5vXri9-", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Summary table: one row per completion with its mean token logprob, plus the\n", | |
"# same value mapped through 100 * e**x so it reads as a percentage.\n", | |
"df = pd.DataFrame({\"text\": texts, \"logprob\": logprobs})\n", | |
"df[\"%\"] = df.logprob.apply(lambda mean_logprob: 100*np.e**mean_logprob)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-vaHwe_1rkTQ", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 359 | |
}, | |
"outputId": "3b00ce51-686d-44eb-aade-497f7c2fb4cf" | |
}, | |
"source": [ | |
"df.sort_values(by=\"logprob\", ascending=False)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>text</th>\n", | |
" <th>logprob</th>\n", | |
" <th>%</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>I like the rain</td>\n", | |
" <td>-1.092228</td>\n", | |
" <td>33.546824</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>That's a cool raincoat</td>\n", | |
" <td>-1.343981</td>\n", | |
" <td>26.080522</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>It's really fun to ride</td>\n", | |
" <td>-1.435801</td>\n", | |
" <td>23.792467</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>It goes \"Chugga chugga chugga\"</td>\n", | |
" <td>-1.829431</td>\n", | |
" <td>16.050481</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>I can't find my brain</td>\n", | |
" <td>-1.907298</td>\n", | |
" <td>14.848097</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>It's a very long train</td>\n", | |
" <td>-3.161690</td>\n", | |
" <td>4.235411</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>That's the brain train</td>\n", | |
" <td>-3.285344</td>\n", | |
" <td>3.742771</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>It's made of tin and rain</td>\n", | |
" <td>-3.425831</td>\n", | |
" <td>3.252225</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>I'm getting wet again</td>\n", | |
" <td>-4.019777</td>\n", | |
" <td>1.795697</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>The strain of that train</td>\n", | |
" <td>-4.875232</td>\n", | |
" <td>0.763333</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" text logprob %\n", | |
"5 I like the rain -1.092228 33.546824\n", | |
"9 That's a cool raincoat -1.343981 26.080522\n", | |
"6 It's really fun to ride -1.435801 23.792467\n", | |
"2 It goes \"Chugga chugga chugga\" -1.829431 16.050481\n", | |
"7 I can't find my brain -1.907298 14.848097\n", | |
"0 It's a very long train -3.161690 4.235411\n", | |
"8 That's the brain train -3.285344 3.742771\n", | |
"3 It's made of tin and rain -3.425831 3.252225\n", | |
"4 I'm getting wet again -4.019777 1.795697\n", | |
"1 The strain of that train -4.875232 0.763333" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 242 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "4L5-zOSus9qt", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"So the problem now is that the completion with the highest average logprob isn't even the best one! We'll skin that cat later, but for now, what if we also get rid of the repetition? This might not work, since the format itself repeats words, but we'll penalize repetition just a tad." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "rafy5RJGrqRi", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"kwargs[\"logprobs\"] = 1\n", | |
"kwargs[\"max_tokens\"] = 40\n", | |
"kwargs[\"temperature\"] = .5\n", | |
"kwargs[\"stop\"] = \"-----\"\n", | |
"kwargs[\"n\"] = 5\n", | |
"kwargs[\"frequency_penalty\"] = .1\n" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab_type": "code", | |
"id": "IIcHlTH6tild", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 204 | |
}, | |
"outputId": "3166ea6b-f2bf-45e2-e0fd-c896ed0c603e" | |
}, | |
"source": [ | |
"# Sweep frequency_penalty from -0.5 to 0.5 in steps of 0.1 and collect one\n", | |
"# summary frame per setting in rslts.\n", | |
"rslts = []\n", | |
"for frequency_penalty in range(-5, 6):\n", | |
"    print(frequency_penalty)\n", | |
"    penalty = np.round(.1 * frequency_penalty, 1)\n", | |
"    kwargs[\"frequency_penalty\"] = penalty\n", | |
"    r = openai.Completion.create(prompt=prompt, **kwargs)\n", | |
"\n", | |
"    texts = []\n", | |
"    logprobs = []\n", | |
"    for choice in r[\"choices\"]:\n", | |
"        # The second-to-last line of the completion is its Rhyme line;\n", | |
"        # [7:] drops the 7-character 'Rhyme: ' label.\n", | |
"        texts.append(choice[\"text\"].split(\"\\n\")[-2][7:])\n", | |
"        # Skip the first 18 completion tokens (same offset as the cells above).\n", | |
"        # NOTE(review): the fixed offset assumes the text before the rhyme\n", | |
"        # line always tokenizes to 18 tokens -- confirm for multi-token words.\n", | |
"        tokens_df = pd.DataFrame(choice[\"logprobs\"])[18:]\n", | |
"        # Average only up to the first newline, i.e. the end of the rhyme line.\n", | |
"        rhyme_end = tokens_df.tokens.to_list().index(\"\\n\")\n", | |
"        logprobs.append(tokens_df[:rhyme_end].token_logprobs.mean())\n", | |
"\n", | |
"    df = pd.DataFrame({\"text\": texts, \"logprob\": logprobs})\n", | |
"    df[\"%\"] = df.logprob.apply(lambda x: 100*np.e**x)\n", | |
"    # sort_values returns a new frame; the result has to be assigned to\n", | |
"    # take effect.\n", | |
"    df = df.sort_values(by=\"logprob\", ascending=False)\n", | |
"    df[\"frequency_penalty\"] = penalty\n", | |
"    rslts.append(df.copy())" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"-5\n", | |
"-4\n", | |
"-3\n", | |
"-2\n", | |
"-1\n", | |
"0\n", | |
"1\n", | |
"2\n", | |
"3\n", | |
"4\n", | |
"5\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "yTdEYFUtyJq8", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Now, one of the big things we should realize is that changing the penalty likely influences the absolute value of the logprobs; \"that' a cool rain\" has basically the same logprob at .3 for some reason, but it drops off significantly at -.3 and .5." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "p9z6KxubxTm8", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
}, | |
"outputId": "150a8056-a1c5-4842-f178-0f75f8294a62" | |
}, | |
"source": [ | |
"pd.concat(rslts).sort_values(by=\"%\", ascending=False)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>text</th>\n", | |
" <th>logprob</th>\n", | |
" <th>%</th>\n", | |
" <th>frequency_penalty</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>That's a cool train</td>\n", | |
" <td>-0.304744</td>\n", | |
" <td>73.731213</td>\n", | |
" <td>-0.4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td></td>\n", | |
" <td>-0.618064</td>\n", | |
" <td>53.898677</td>\n", | |
" <td>-0.4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>That train is so pain ='(</td>\n", | |
" <td>-0.764359</td>\n", | |
" <td>46.563253</td>\n", | |
" <td>0.1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>That's a cool plane</td>\n", | |
" <td>-0.946062</td>\n", | |
" <td>38.826682</td>\n", | |
" <td>0.3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>That's a cool rain</td>\n", | |
" <td>-0.949751</td>\n", | |
" <td>38.683722</td>\n", | |
" <td>0.3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>That's a cool rain</td>\n", | |
" <td>-0.949796</td>\n", | |
" <td>38.682010</td>\n", | |
" <td>0.3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>That train was a strain</td>\n", | |
" <td>-1.077294</td>\n", | |
" <td>34.051571</td>\n", | |
" <td>0.1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>It makes a lot of noise</td>\n", | |
" <td>-1.172296</td>\n", | |
" <td>30.965518</td>\n", | |
" <td>-0.4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>My train has a strain</td>\n", | |
" <td>-1.250902</td>\n", | |
" <td>28.624662</td>\n", | |
" <td>0.1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>It's so long and it's so tall</td>\n", | |
" <td>-1.306161</td>\n", | |
" <td>27.085801</td>\n", | |
" <td>-0.5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>It's raining outside</td>\n", | |
" <td>-1.312717</td>\n", | |
" <td>26.908783</td>\n", | |
" <td>-0.3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>That's a cool rain</td>\n", | |
" <td>-1.356037</td>\n", | |
" <td>25.767990</td>\n", | |
" <td>-0.3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>That's a painful train</td>\n", | |
" <td>-1.401003</td>\n", | |
" <td>24.634987</td>\n", | |
" <td>0.1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>I want to be a train driver</td>\n", | |
" <td>-1.417063</td>\n", | |
" <td>24.242492</td>\n", | |
" <td>0.2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>It's raining outside</td>\n", | |
" <td>-1.427773</td>\n", | |
" <td>23.984234</td>\n", | |
" <td>0.2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>That train was a pain</td>\n", | |
" <td>-1.431793</td>\n", | |
" <td>23.888017</td>\n", | |
" <td>0.3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>That's a train of pain</td>\n", | |
" <td>-1.437199</td>\n", | |
" <td>23.759227</td>\n", | |
" <td>-0.1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>That's a cool rain</td>\n", | |
" <td>-1.451405</td>\n", | |
" <td>23.424087</td>\n", | |
" <td>0.5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>It's so cool it's a plane</td>\n", | |
" <td>-1.453801</td>\n", | |
" <td>23.368034</td>\n", | |
" <td>0.4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>I got hit by a train ='(</td>\n", | |
" <td>-1.465098</td>\n", | |
" <td>23.105533</td>\n", | |
" <td>0.4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td></td>\n", | |
" <td>-1.488245</td>\n", | |
" <td>22.576853</td>\n", | |
" <td>-0.3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>But I don't like pain ='(</td>\n", | |
" <td>-1.499659</td>\n", | |
" <td>22.320621</td>\n", | |
" <td>0.4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>That train flies through the air</td>\n", | |
" <td>-1.502256</td>\n", | |
" <td>22.262740</td>\n", | |
" <td>0.2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>That's a really cool train</td>\n", | |
" <td>-1.514158</td>\n", | |
" <td>21.999342</td>\n", | |
" <td>0.2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>I can't find my train</td>\n", | |
" <td>-1.523579</td>\n", | |
" <td>21.793059</td>\n", | |
" <td>-0.2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>That train made a big rain ='(</td>\n", | |
" <td>-1.538598</td>\n", | |
" <td>21.468188</td>\n", | |
" <td>0.4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>It makes me feel pain</td>\n", | |
" <td>-1.558070</td>\n", | |
" <td>21.054195</td>\n", | |
" <td>-0.2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>The train is a pain</td>\n", | |
" <td>-1.633714</td>\n", | |
" <td>19.520319</td>\n", | |
" <td>-0.2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>It goes on and on and on</td>\n", | |
" <td>-1.637339</td>\n", | |
" <td>19.449682</td>\n", | |
" <td>-0.1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>That's a cool pain</td>\n", | |
" <td>-1.637495</td>\n", | |
" <td>19.446663</td>\n", | |
" <td>-0.2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>But it's a pain to clean up the tracks</td>\n", | |
" <td>-1.639166</td>\n", | |
" <td>19.414195</td>\n", | |
" <td>0.5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>That train's a pain</td>\n", | |
" <td>-1.647220</td>\n", | |
" <td>19.258458</td>\n", | |
" <td>-0.1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>That train goes through the rain</td>\n", | |
" <td>-1.686775</td>\n", | |
" <td>18.511552</td>\n", | |
" <td>0.1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td></td>\n", | |
" <td>-1.691131</td>\n", | |
" <td>18.431098</td>\n", | |
" <td>-0.5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>It rains on the train</td>\n", | |
" <td>-1.793993</td>\n", | |
" <td>16.629476</td>\n", | |
" <td>-0.1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>My brain's on a train</td>\n", | |
" <td>-1.799681</td>\n", | |
" <td>16.535157</td>\n", | |
" <td>-0.5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>I'm gonna fly on that plane</td>\n", | |
" <td>-1.814585</td>\n", | |
" <td>16.290550</td>\n", | |
" <td>-0.2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>It's the rain that made it cool</td>\n", | |
" <td>-1.823289</td>\n", | |
" <td>16.149373</td>\n", | |
" <td>-0.3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>It goes on and on</td>\n", | |
" <td>-1.829408</td>\n", | |
" <td>16.050864</td>\n", | |
" <td>0.3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>That's a big brain</td>\n", | |
" <td>-1.866597</td>\n", | |
" <td>15.464904</td>\n", | |
" <td>-0.5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>That's a great rain</td>\n", | |
" <td>-1.927979</td>\n", | |
" <td>14.544180</td>\n", | |
" <td>0.5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>It's a big rain</td>\n", | |
" <td>-1.937891</td>\n", | |
" <td>14.400728</td>\n", | |
" <td>-0.3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>I love to watch it rain</td>\n", | |
" <td>-1.984703</td>\n", | |
" <td>13.742141</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>That train runs on pain</td>\n", | |
" <td>-1.992455</td>\n", | |
" <td>13.636021</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>The train is so fast</td>\n", | |
" <td>-1.997649</td>\n", | |
" <td>13.565388</td>\n", | |
" <td>-0.4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>The rain is so fun</td>\n", | |
" <td>-2.000404</td>\n", | |
" <td>13.528068</td>\n", | |
" <td>-0.1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>That's the brain train</td>\n", | |
" <td>-2.005497</td>\n", | |
" <td>13.459339</td>\n", | |
" <td>0.5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>I want to ride that train</td>\n", | |
" <td>-2.009030</td>\n", | |
" <td>13.411870</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>It goes up and down the rain</td>\n", | |
" <td>-2.016399</td>\n", | |
" <td>13.313404</td>\n", | |
" <td>0.5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>I had a train pain</td>\n", | |
" <td>-2.069944</td>\n", | |
" <td>12.619285</td>\n", | |
" <td>0.4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>It is so cool to play with pain</td>\n", | |
" <td>-2.074948</td>\n", | |
" <td>12.556298</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>That train's so awesome</td>\n", | |
" <td>-2.165770</td>\n", | |
" <td>11.466161</td>\n", | |
" <td>-0.4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>That train was so long</td>\n", | |
" <td>-2.189222</td>\n", | |
" <td>11.200386</td>\n", | |
" <td>-0.5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>It makes me feel happy again</td>\n", | |
" <td>-2.263620</td>\n", | |
" <td>10.397338</td>\n", | |
" <td>0.2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>It makes me feel fine</td>\n", | |
" <td>-2.304157</td>\n", | |
" <td>9.984290</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" text ... frequency_penalty\n", | |
"3 That's a cool train ... -0.4\n", | |
"0 ... -0.4\n", | |
"3 That train is so pain ='( ... 0.1\n", | |
"1 That's a cool plane ... 0.3\n", | |
"0 That's a cool rain ... 0.3\n", | |
"4 That's a cool rain ... 0.3\n", | |
"2 That train was a strain ... 0.1\n", | |
"2 It makes a lot of noise ... -0.4\n", | |
"4 My train has a strain ... 0.1\n", | |
"1 It's so long and it's so tall ... -0.5\n", | |
"0 It's raining outside ... -0.3\n", | |
"3 That's a cool rain ... -0.3\n", | |
"0 That's a painful train ... 0.1\n", | |
"2 I want to be a train driver ... 0.2\n", | |
"3 It's raining outside ... 0.2\n", | |
"2 That train was a pain ... 0.3\n", | |
"3 That's a train of pain ... -0.1\n", | |
"3 That's a cool rain ... 0.5\n", | |
"0 It's so cool it's a plane ... 0.4\n", | |
"3 I got hit by a train ='( ... 0.4\n", | |
"1 ... -0.3\n", | |
"4 But I don't like pain ='( ... 0.4\n", | |
"0 That train flies through the air ... 0.2\n", | |
"1 That's a really cool train ... 0.2\n", | |
"0 I can't find my train ... -0.2\n", | |
"2 That train made a big rain ='( ... 0.4\n", | |
"3 It makes me feel pain ... -0.2\n", | |
"4 The train is a pain ... -0.2\n", | |
"1 It goes on and on and on ... -0.1\n", | |
"2 That's a cool pain ... -0.2\n", | |
"0 But it's a pain to clean up the tracks ... 0.5\n", | |
"2 That train's a pain ... -0.1\n", | |
"1 That train goes through the rain ... 0.1\n", | |
"2 ... -0.5\n", | |
"4 It rains on the train ... -0.1\n", | |
"4 My brain's on a train ... -0.5\n", | |
"1 I'm gonna fly on that plane ... -0.2\n", | |
"4 It's the rain that made it cool ... -0.3\n", | |
"3 It goes on and on ... 0.3\n", | |
"3 That's a big brain ... -0.5\n", | |
"1 That's a great rain ... 0.5\n", | |
"2 It's a big rain ... -0.3\n", | |
"3 I love to watch it rain ... 0.0\n", | |
"1 That train runs on pain ... 0.0\n", | |
"1 The train is so fast ... -0.4\n", | |
"0 The rain is so fun ... -0.1\n", | |
"4 That's the brain train ... 0.5\n", | |
"2 I want to ride that train ... 0.0\n", | |
"2 It goes up and down the rain ... 0.5\n", | |
"1 I had a train pain ... 0.4\n", | |
"0 It is so cool to play with pain ... 0.0\n", | |
"4 That train's so awesome ... -0.4\n", | |
"0 That train was so long ... -0.5\n", | |
"4 It makes me feel happy again ... 0.2\n", | |
"4 It makes me feel fine ... 0.0\n", | |
"\n", | |
"[55 rows x 4 columns]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 272 | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment