@sgbaird
Last active May 7, 2025 14:45
hf-crabnet-hyperparameter.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMd9ZHgJaXAMV1ysyAhR3dj",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/sgbaird/78fbb50753c1089f487152817779fd74/hf-crabnet-hyperparameter.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Turing Test Optimization Benchmark\n",
"\n",
"Programmatic control of an advanced optimization task [hosted on Hugging Face Spaces](https://huggingface.co/spaces/AccelerationConsortium/crabnet-hyperparameter) based on the CrabNet hyperparameter tuning task. A high-level explanation of the benchmark is given at the Hugging Face link above. Details about the original benchmarking task can be found in the following manuscript:\n",
"\n",
"- https://doi.org/10.1016/j.commatsci.2022.111505\n",
"\n",
"Details about the creation of the Turing test benchmark are in the following manuscript:\n",
"\n",
"- https://doi.org/10.26434/chemrxiv-2023-9s6r7"
],
"metadata": {
"id": "gQgCoCvFQ5eX"
}
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SUhzgYV7kAKa",
"outputId": "f89bc04c-9dad-477d-c111-41e0fcf0af60"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: gradio_client in /usr/local/lib/python3.11/dist-packages (1.10.0)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from gradio_client) (2025.3.2)\n",
"Requirement already satisfied: httpx>=0.24.1 in /usr/local/lib/python3.11/dist-packages (from gradio_client) (0.28.1)\n",
"Requirement already satisfied: huggingface-hub>=0.19.3 in /usr/local/lib/python3.11/dist-packages (from gradio_client) (0.30.2)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from gradio_client) (24.2)\n",
"Requirement already satisfied: typing-extensions~=4.0 in /usr/local/lib/python3.11/dist-packages (from gradio_client) (4.13.2)\n",
"Requirement already satisfied: websockets<16.0,>=10.0 in /usr/local/lib/python3.11/dist-packages (from gradio_client) (15.0.1)\n",
"Requirement already satisfied: anyio in /usr/local/lib/python3.11/dist-packages (from httpx>=0.24.1->gradio_client) (4.9.0)\n",
"Requirement already satisfied: certifi in /usr/local/lib/python3.11/dist-packages (from httpx>=0.24.1->gradio_client) (2025.4.26)\n",
"Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.11/dist-packages (from httpx>=0.24.1->gradio_client) (1.0.9)\n",
"Requirement already satisfied: idna in /usr/local/lib/python3.11/dist-packages (from httpx>=0.24.1->gradio_client) (3.10)\n",
"Requirement already satisfied: h11>=0.16 in /usr/local/lib/python3.11/dist-packages (from httpcore==1.*->httpx>=0.24.1->gradio_client) (0.16.0)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.19.3->gradio_client) (3.18.0)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.19.3->gradio_client) (6.0.2)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.19.3->gradio_client) (2.32.3)\n",
"Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.19.3->gradio_client) (4.67.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio->httpx>=0.24.1->gradio_client) (1.3.1)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub>=0.19.3->gradio_client) (3.4.1)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub>=0.19.3->gradio_client) (2.4.0)\n"
]
}
],
"source": [
"%pip install gradio_client"
]
},
{
"cell_type": "code",
"source": [
"from gradio_client import Client\n",
"\n",
"client = Client(\"AccelerationConsortium/crabnet-hyperparameter\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6ElcM3x6MNfK",
"outputId": "4114fba2-e85e-41f4-8982-6c3affa69cf2"
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Loaded as API: https://accelerationconsortium-crabnet-hyperparameter.hf.space ✔\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"result = client.predict(\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x1' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x2' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x3' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0)in 'x4' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x5' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x6' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x7' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x8' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x9' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x10' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x11' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0000000000000002) in 'x12' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x13' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x14' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x15' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x16' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x17' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x18' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 0.9999999999999998) in 'x19' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 0.9999999999999998) in 'x20' Slider component\n",
"\t\t\"c1_0\",\t# Literal['c1_0', 'c1_1'] in 'c1' Radio component\n",
"\t\t\"c2_0\",\t# Literal['c2_0', 'c2_1'] in 'c2' Radio component\n",
"\t\t\"c3_0\",\t# Literal['c3_0', 'c3_1', 'c3_2'] in 'c3' Radio component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'fidelity1' Slider component\n",
"\t\tapi_name=\"/predict\"\n",
")\n",
"print(result['data'][0])\n",
"\n",
"# return type:\n",
"# Dict(headers: List[str], data: List[List[Any]], metadata: Dict(str, List[Any] | None) | None) representing output in 'output' Dataframe component"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "njRIl4GoMQmH",
"outputId": "b6b41f4d-b002-4b05-a6ae-17b8813fed0f"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[0.9771932187381239, 1.4907550714587514, 17.614516736284173, 2001617.8575317992]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"result = client.predict(\n",
"\t\t0.2222222222222222,\t# float (numeric value between 0.0 and 1.0) in 'x1' Slider component\n",
"\t\t0.5,\t# float (numeric value between 0.0 and 1.0) in 'x2' Slider component\n",
"\t\t0.4458874458874459,\t# float (numeric value between 0.0 and 1.0) in 'x3' Slider component\n",
"\t\t0.3333333333333333,\t# float (numeric value between 0.0 and 1.0)in 'x4' Slider component\n",
"\t\t0.1,\t# float (numeric value between 0.0 and 1.0) in 'x5' Slider component\n",
"\t\t0.5,\t# float (numeric value between 0.0 and 1.0) in 'x6' Slider component\n",
"\t\t0.3333333333333333,\t# float (numeric value between 0.0 and 1.0) in 'x7' Slider component\n",
"\t\t0.009009009009009009,\t# float (numeric value between 0.0 and 1.0) in 'x8' Slider component\n",
"\t\t0.2,\t# float (numeric value between 0.0 and 1.0) in 'x9' Slider component\n",
"\t\t0.3333333333333333,\t# float (numeric value between 0.0 and 1.0) in 'x10' Slider component\n",
"\t\t0.5,\t# float (numeric value between 0.0 and 1.0) in 'x11' Slider component\n",
"\t\t0.15254237288135594,\t# float (numeric value between 0.0 and 1.0000000000000002) in 'x12' Slider component\n",
"\t\t0.33333333333333337,\t# float (numeric value between 0.0 and 1.0) in 'x13' Slider component\n",
"\t\t0.33333333333333337,\t# float (numeric value between 0.0 and 1.0) in 'x14' Slider component\n",
"\t\t0.5,\t# float (numeric value between 0.0 and 1.0) in 'x15' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x16' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x17' Slider component\n",
"\t\t0.2,\t# float (numeric value between 0.0 and 1.0) in 'x18' Slider component\n",
"\t\t0.8001600320064013,\t# float (numeric value between 0.0 and 0.9999999999999998) in 'x19' Slider component\n",
"\t\t0.9981996399279855,\t# float (numeric value between 0.0 and 0.9999999999999998) in 'x20' Slider component\n",
"\t\t\"c1_0\",\t# Literal['c1_0', 'c1_1'] in 'c1' Radio component\n",
"\t\t\"c2_0\",\t# Literal['c2_0', 'c2_1'] in 'c2' Radio component\n",
"\t\t\"c3_0\",\t# Literal['c3_0', 'c3_1', 'c3_2'] in 'c3' Radio component\n",
"\t\t0.494949494949495,\t# float (numeric value between 0.0 and 1.0) in 'fidelity1' Slider component\n",
"\t\tapi_name=\"/predict\"\n",
")\n",
"print(result['data'][0])"
],
"metadata": {
"id": "g9rFFxukqq_A"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"To avoid rate limits, duplicate the Advanced Optimization from Acceleration Consortium's Hugging Face space for private use, as shown below. You will need to pass a Hugging Face token for your personal account, generated at https://huggingface.co/settings/tokens and set to allow for `write` access. See [security tokens](https://huggingface.co/docs/hub/en/security-tokens) for more info.\n",
"\n",
"You can add your token as a Colab secret by pressing the \"key\" icon on the sidebar panel. Enter `HF_TOKEN` as the name, and your token as the value (e.g., `hf_a1b2c3`)."
],
"metadata": {
"id": "2oaqbUAeNADZ"
}
},
{
"cell_type": "code",
"source": [
"del client"
],
"metadata": {
"id": "cRPi1hINn5rS"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# load the Advanced Optimization from AC huggingface\n",
"from gradio_client import Client\n",
"\n",
"from google.colab import userdata\n",
"HF_TOKEN = userdata.get('HF_TOKEN')\n",
"\n",
"client = Client.duplicate(\"AccelerationConsortium/crabnet-hyperparameter\", hf_token=HF_TOKEN)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "56ygptwfpik2",
"outputId": "71489303-f472-41b9-ffe5-3874ad4e41f1"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Using your existing Space: https://hf.space/sgbaird/crabnet-hyperparameter 🤗\n",
"\n",
"Loaded as API: https://sgbaird-crabnet-hyperparameter.hf.space ✔\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"result = client.predict(\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x1' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x2' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x3' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0)in 'x4' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x5' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x6' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x7' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x8' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x9' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x10' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x11' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0000000000000002) in 'x12' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x13' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x14' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x15' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x16' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x17' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'x18' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 0.9999999999999998) in 'x19' Slider component\n",
"\t\t0,\t# float (numeric value between 0.0 and 0.9999999999999998) in 'x20' Slider component\n",
"\t\t\"c1_0\",\t# Literal['c1_0', 'c1_1'] in 'c1' Radio component\n",
"\t\t\"c2_0\",\t# Literal['c2_0', 'c2_1'] in 'c2' Radio component\n",
"\t\t\"c3_0\",\t# Literal['c3_0', 'c3_1', 'c3_2'] in 'c3' Radio component\n",
"\t\t0,\t# float (numeric value between 0.0 and 1.0) in 'fidelity1' Slider component\n",
"\t\tapi_name=\"/predict\"\n",
")\n",
"print(result['data'][0])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Y_N0drYMkDyZ",
"outputId": "d4f61f04-bad4-416c-a151-f09b804581e8"
},
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[1.0308926648023387, 1.5329984893334505, 17.54676695304107, 2001617.8575317992]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"def simple_predict(x1=0.5, x2=0.5, x3=0.5, x4=0.5, x5=0.5):\n",
" client.predict(\n",
" x1,\t# float (numeric value between 0.0 and 1.0) in 'x1' Slider component\n",
" x2,\t# float (numeric value between 0.0 and 1.0) in 'x2' Slider component\n",
" x3,\t# float (numeric value between 0.0 and 1.0) in 'x3' Slider component\n",
" x4,\t# float (numeric value between 0.0 and 1.0)in 'x4' Slider component\n",
" x5,\t# float (numeric value between 0.0 and 1.0) in 'x5' Slider component\n",
" 0.0,\t# float (numeric value between 0.0 and 1.0) in 'x6' Slider component\n",
" 0.0,\t# float (numeric value between 0.0 and 1.0) in 'x7' Slider component\n",
" 0.5,\t# float (numeric value between 0.0 and 1.0) in 'x8' Slider component\n",
" 0.5,\t# float (numeric value between 0.0 and 1.0) in 'x9' Slider component\n",
" 0.0,\t# float (numeric value between 0.0 and 1.0) in 'x10' Slider component\n",
" 0.0,\t# float (numeric value between 0.0 and 1.0) in 'x11' Slider component\n",
" 0.0,\t# float (numeric value between 0.0 and 1.0000000000000002) in 'x12' Slider component\n",
" 0.5,\t# float (numeric value between 0.0 and 1.0) in 'x13' Slider component\n",
" 0.0,\t# float (numeric value between 0.0 and 1.0) in 'x14' Slider component\n",
" 0.5,\t# float (numeric value between 0.0 and 1.0) in 'x15' Slider component\n",
" 0.0,\t# float (numeric value between 0.0 and 1.0) in 'x16' Slider component\n",
" 0.0,\t# float (numeric value between 0.0 and 1.0) in 'x17' Slider component\n",
" 0.5,\t# float (numeric value between 0.0 and 1.0) in 'x18' Slider component\n",
" 0.0,\t# float (numeric value between 0.0 and 0.9999999999999998) in 'x19' Slider component\n",
" 0.0,\t# float (numeric value between 0.0 and 0.9999999999999998) in 'x20' Slider component\n",
" \"c1_0\",\t# Literal['c1_0', 'c1_1'] in 'c1' Radio component\n",
" \"c2_0\",\t# Literal['c2_0', 'c2_1'] in 'c2' Radio component\n",
" \"c3_0\",\t# Literal['c3_0', 'c3_1', 'c3_2'] in 'c3' Radio component\n",
" 0.5,\t# float (numeric value between 0.0 and 1.0) in 'fidelity1' Slider component\n",
" api_name=\"/predict\"\n",
" )\n",
" y1 = result['data'][0][0]\n",
" return y1"
],
"metadata": {
"id": "waFUMCVikHED"
},
"execution_count": 22,
"outputs": []
},
{
"cell_type": "code",
"source": [
"y1a = simple_predict()\n",
"print(y1a)\n",
"\n",
"y1b = simple_predict()\n",
"print(y1b)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5YhD8Kb1kvvZ",
"outputId": "fb3ec38b-66b1-42f8-e3b3-6af188dc56bb"
},
"execution_count": 23,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0.6990967070382647\n",
"0.6990967070382647\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"params = {\n",
" \"x1\": 0.5,\n",
" \"x2\": 0.5,\n",
" \"x3\": 0.5,\n",
" \"x4\": 0.5,\n",
" \"x5\": 0.5,\n",
" \"x6\": 0,\n",
" \"x7\": 0,\n",
" \"x8\": 0,\n",
" \"x9\": 0,\n",
" \"x10\": 0,\n",
" \"x11\": 0,\n",
" \"x12\": 0,\n",
" \"x13\": 0,\n",
" \"x14\": 0,\n",
" \"x15\": 0,\n",
" \"x16\": 0,\n",
" \"x17\": 0,\n",
" \"x18\": 0,\n",
" \"x19\": 0,\n",
" \"x20\": 0,\n",
" \"c1\": \"c1_0\",\n",
" \"c2\": \"c2_0\",\n",
" \"c3\": \"c3_0\",\n",
" \"fidelity1\": 0.5\n",
"}\n",
"\n",
"result = client.predict(*params.values())\n",
"print(result)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "J0_AiOSbk6p-",
"outputId": "eac52218-98fc-4778-b1a2-cb660463fcdf"
},
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{'headers': ['y1', 'y2', 'y3', 'y4'], 'data': [[0.6990967070382647, 1.2493484815536882, 290.2608641300176, 27954248.007549733]], 'metadata': None}\n"
]
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "QKX5GjWUoQMR"
},
"execution_count": null,
"outputs": []
}
]
}
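The `client.predict` calls in the notebook above take 24 positional inputs whose order and ranges are documented only in the auto-generated comments. Below is a minimal sketch of a hypothetical helper, `build_args`, that turns a parameter dict (like `params` in the last cell) into that argument list and rejects out-of-range values before a request is sent. It assumes the order and bounds listed in those comments, which the Space could change.

```python
from typing import Any, Dict, List, Tuple

# Order and ranges copied from the auto-generated comments in the notebook's
# client.predict calls (an assumption: the Space may change them).
ORDER: List[str] = [f"x{i}" for i in range(1, 21)] + ["c1", "c2", "c3", "fidelity1"]

FLOAT_BOUNDS: Dict[str, Tuple[float, float]] = {f"x{i}": (0.0, 1.0) for i in range(1, 21)}
FLOAT_BOUNDS["x12"] = (0.0, 1.0000000000000002)
FLOAT_BOUNDS["x19"] = (0.0, 0.9999999999999998)
FLOAT_BOUNDS["x20"] = (0.0, 0.9999999999999998)
FLOAT_BOUNDS["fidelity1"] = (0.0, 1.0)

CHOICES: Dict[str, List[str]] = {
    "c1": ["c1_0", "c1_1"],
    "c2": ["c2_0", "c2_1"],
    "c3": ["c3_0", "c3_1", "c3_2"],
}


def build_args(params: Dict[str, Any]) -> List[Any]:
    """Return the 24 positional inputs in endpoint order, checking names, ranges and choices."""
    missing = [name for name in ORDER if name not in params]
    unknown = sorted(set(params) - set(ORDER))
    if missing or unknown:
        raise KeyError(f"missing: {missing}, unknown: {unknown}")
    args: List[Any] = []
    for name in ORDER:
        value = params[name]
        if name in CHOICES:
            # Radio components accept only one of the listed literals.
            if value not in CHOICES[name]:
                raise ValueError(f"{name}={value!r} is not one of {CHOICES[name]}")
        else:
            # Slider components take floats within the listed range.
            value = float(value)
            low, high = FLOAT_BOUNDS[name]
            if not low <= value <= high:
                raise ValueError(f"{name}={value} is outside [{low}, {high}]")
        args.append(value)
    return args
```

With the duplicated `client`, a call would then look like `client.predict(*build_args(params), api_name="/predict")`. Checking inputs locally avoids spending a request, and rate-limit budget, on a value outside the documented slider ranges or radio choices.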
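Every cell in the notebook above reads the output positionally (`result['data'][0]`), although the recorded output also carries the column names `y1` to `y4` under `headers`. A small sketch of a hypothetical helper that pairs the two, assuming (as in every output recorded in the notebook) that the Space returns a single row per evaluation:

```python
from typing import Any, Dict


def result_to_dict(result: Dict[str, Any]) -> Dict[str, Any]:
    """Pair the Dataframe output's column headers with its first (only) row."""
    headers = result["headers"]
    row = result["data"][0]  # one row per evaluation in the outputs recorded above
    if len(headers) != len(row):
        raise ValueError(f"{len(headers)} headers but {len(row)} values")
    return dict(zip(headers, row))


# Output recorded by the notebook's last cell
recorded = {
    "headers": ["y1", "y2", "y3", "y4"],
    "data": [[0.6990967070382647, 1.2493484815536882, 290.2608641300176, 27954248.007549733]],
    "metadata": None,
}
print(result_to_dict(recorded)["y1"])  # 0.6990967070382647
```

Looking objectives up by name keeps downstream code correct even if the column order of the output Dataframe changes.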
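The notebook stops at single evaluations. One illustrative way to drive `simple_predict` in a loop is a plain random-search baseline, sketched below. This is not the method from the referenced manuscripts, and it assumes lower `y1` is better, which the notebook itself does not state. The objective is passed in as a function, so the sketch can be tried offline with a stand-in before any requests are sent to the Space.

```python
import random
from typing import Callable, Dict, List, Tuple


def random_search(
    objective: Callable[..., float],
    names: List[str],
    n_iter: int = 20,
    seed: int = 0,
) -> Tuple[Dict[str, float], float]:
    """Minimise `objective` by sampling each named input uniformly from [0, 1)."""
    rng = random.Random(seed)  # local RNG, so the sampled points are reproducible
    best_x: Dict[str, float] = {}
    best_y = float("inf")
    for _ in range(n_iter):
        x = {name: rng.random() for name in names}
        y = objective(**x)  # inputs are passed as keyword arguments, e.g. x1=..., x2=...
        if y < best_y:
            best_x, best_y = x, y
    return best_x, best_y


# Offline stand-in objective (hypothetical) with its minimum at 0.3 in every input
def toy(**kw: float) -> float:
    return sum((v - 0.3) ** 2 for v in kw.values())


print(random_search(toy, ["x1", "x2"], n_iter=50, seed=1))
```

Against the Space, `random_search(simple_predict, ["x1", "x2", "x3", "x4", "x5"], n_iter=10)` would issue ten requests; this relies on `simple_predict` assigning the return value of `client.predict` to `result` before indexing it. The fixed seed makes the sampled points reproducible, but the returned objective values may not be: the all-zero inputs gave different outputs on the original and the duplicated Space.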