Created
June 18, 2023 02:12
-
-
Save j40903272/fb5ab7b3e39fd4adabe1c00962b1df71 to your computer and use it in GitHub Desktop.
research paper editor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "raw", | |
"id": "b77c4818-9d92-4581-a2fb-2b38da776a8b", | |
"metadata": {}, | |
"source": [ | |
"# python >= 3.8\n", | |
"import sys\n", | |
"!{sys.executable} -m pip install langchain gradio tiktoken unstructured" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "80b2033d-b985-440d-a01f-e7f8547a6801", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/requests/__init__.py:102: RequestsDependencyWarning: urllib3 (1.26.8) or chardet (5.1.0)/charset_normalizer (2.0.12) doesn't match a supported version!\n", | |
" warnings.warn(\"urllib3 ({}) or chardet ({})/charset_normalizer ({}) doesn't match a supported \"\n", | |
"/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | |
" from .autonotebook import tqdm as notebook_tqdm\n" | |
] | |
} | |
], | |
"source": [ | |
"# from langchain.text_splitter import LatexTextSplitter\n", | |
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n", | |
"from typing import Any\n", | |
"import requests\n", | |
"import logging\n", | |
"import json\n", | |
"import tiktoken\n", | |
"import gradio as gr\n", | |
"from langchain.document_loaders import UnstructuredPDFLoader" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "720927dd-848e-47ff-a847-4e1097768561", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"turbo_encoding = tiktoken.encoding_for_model(\"gpt-3.5-turbo\")\n", | |
"with open(\"sample.tex\", \"r\") as f:\n", | |
" content = f.read()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "13bc48fd-a048-4635-99c4-79ddb62d5c4d", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"class LatexTextSplitter(RecursiveCharacterTextSplitter):\n", | |
" \"\"\"Attempts to split the text along Latex-formatted layout elements.\"\"\"\n", | |
"\n", | |
" def __init__(self, **kwargs: Any):\n", | |
" \"\"\"Initialize a LatexTextSplitter.\"\"\"\n", | |
" separators = [\n", | |
" # First, try to split along Latex sections\n", | |
" \"\\chapter{\",\n", | |
" \"\\section{\",\n", | |
" \"\\subsection{\",\n", | |
" \"\\subsubsection{\",\n", | |
"\n", | |
" # Now split by environments\n", | |
" \"\\begin{\"\n", | |
" # \"\\n\\\\begin{enumerate}\",\n", | |
" # \"\\n\\\\begin{itemize}\",\n", | |
" # \"\\n\\\\begin{description}\",\n", | |
" # \"\\n\\\\begin{list}\",\n", | |
" # \"\\n\\\\begin{quote}\",\n", | |
" # \"\\n\\\\begin{quotation}\",\n", | |
" # \"\\n\\\\begin{verse}\",\n", | |
" # \"\\n\\\\begin{verbatim}\",\n", | |
"\n", | |
" ## Now split by math environments\n", | |
" # \"\\n\\\\begin{align}\",\n", | |
" # \"$$\",\n", | |
" # \"$\",\n", | |
"\n", | |
" # Now split by the normal type of lines\n", | |
" \" \",\n", | |
" \"\",\n", | |
" ]\n", | |
" super().__init__(separators=separators, **kwargs)\n", | |
"\n", | |
"\n", | |
"def json_validator(text: str, openai_key: str, retry: int = 3):\n", | |
" for _ in range(retry):\n", | |
" try:\n", | |
" return json.loads(text)\n", | |
" except Exception:\n", | |
" \n", | |
" try:\n", | |
" prompt = f\"Modify the following into a valid json format:\\n{text}\"\n", | |
" prompt_token_length = len(turbo_encoding.encode(prompt))\n", | |
"\n", | |
" data = {\n", | |
" \"model\": \"text-davinci-003\",\n", | |
" \"prompt\": prompt,\n", | |
" \"max_tokens\": 4097 - prompt_token_length - 64\n", | |
" }\n", | |
" headers = {\n", | |
" \"Content-Type\": \"application/json\",\n", | |
" \"Authorization\": f\"Bearer {openai_key}\"\n", | |
" }\n", | |
" for _ in range(retry):\n", | |
" response = requests.post(\n", | |
" 'https://api.openai.com/v1/completions',\n", | |
" json=data,\n", | |
" headers=headers,\n", | |
" timeout=300\n", | |
" )\n", | |
" if response.status_code != 200:\n", | |
" logging.warning(f'fetch openai chat retry: {response.text}')\n", | |
" continue\n", | |
" text = response.json()['choices'][0]['text']\n", | |
" break\n", | |
" except:\n", | |
" return response.json()['error']\n", | |
" \n", | |
" return text" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "a34bf526-fb0c-4b5e-8c3d-bab7a13e5fe7", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"def analyze(latex_whole_document: str, openai_key: str, progress):\n", | |
" \n", | |
" logging.info(\"start analysis\")\n", | |
" \n", | |
" output_format = \"\"\"\n", | |
"\n", | |
" ```json\n", | |
" [\n", | |
" \\\\ Potential point for improvement 1\n", | |
" {{\n", | |
" \"title\": string \\\\ What this modification is about\n", | |
" \"thought\": string \\\\ The reason why this should be improved\n", | |
" \"action\": string \\\\ how to make improvement\n", | |
" \"original\": string \\\\ the original latex snippet that can be improved\n", | |
" \"improved\": string \\\\ the improved latex snippet which address your point\n", | |
" }},\n", | |
" {{}}\n", | |
" ]\n", | |
" ```\n", | |
" \"\"\"\n", | |
" \n", | |
" chunk_size = 1000\n", | |
" # for _ in range(5):\n", | |
" # try:\n", | |
" # latex_splitter = LatexTextSplitter(\n", | |
" # chunk_size=min(chunk_size, len(latex_whole_document)),\n", | |
" # chunk_overlap=0,\n", | |
" # )\n", | |
" # docs = latex_splitter.create_documents([latex_whole_document])\n", | |
" # break\n", | |
" # except:\n", | |
" # chunk_size // 2\n", | |
"\n", | |
" latex_splitter = LatexTextSplitter(\n", | |
" chunk_size=min(chunk_size, len(latex_whole_document)),\n", | |
" chunk_overlap=0,\n", | |
" )\n", | |
" docs = latex_splitter.create_documents([latex_whole_document])\n", | |
" \n", | |
" progress(0.05)\n", | |
" ideas = []\n", | |
" for doc in progress.tqdm(docs):\n", | |
"\n", | |
" prompt = f\"\"\"\n", | |
" ```\n", | |
" {doc.page_content}\n", | |
" ```\n", | |
" I'm a computer science student. The above is my research paper.\n", | |
" You are my editor.\n", | |
" Your goal is to improve the paper quality at your best.\n", | |
" Point out the parts that can be improved.\n", | |
" Focus on grammar, writing, content, section structure.\n", | |
" Ignore comments and use packages.\n", | |
" List out all the points with a latex snippet which is the improved version addressing your point.\n", | |
" Same paragraph should be only address once.\n", | |
" Output the response in the following valid json format:\n", | |
" {output_format}\n", | |
"\n", | |
" \"\"\"\n", | |
" \n", | |
" idea = fetch_chat(prompt, openai_key)\n", | |
" if isinstance(idea, list):\n", | |
" ideas += idea\n", | |
" break\n", | |
" else:\n", | |
" raise gr.Error(idea)\n", | |
"\n", | |
" logging.info('complete analysis')\n", | |
" return ideas\n", | |
"\n", | |
"\n", | |
"def fetch_chat(prompt: str, openai_key: str, retry: int = 3):\n", | |
" json = {\n", | |
" \"model\": \"gpt-3.5-turbo-16k\",\n", | |
" \"messages\": [{\"role\": \"user\", \"content\": prompt}]\n", | |
" }\n", | |
" headers = {\n", | |
" \"Content-Type\": \"application/json\",\n", | |
" \"Authorization\": f\"Bearer {openai_key}\"\n", | |
" }\n", | |
" for _ in range(retry):\n", | |
" response = requests.post(\n", | |
" 'https://api.openai.com/v1/chat/completions',\n", | |
" json=json,\n", | |
" headers=headers,\n", | |
" timeout=300\n", | |
" )\n", | |
" if response.status_code != 200:\n", | |
" logging.warning(f'fetch openai chat retry: {response.text}')\n", | |
" continue\n", | |
" result = response.json()['choices'][0]['message']['content']\n", | |
" return json_validator(result, openai_key)\n", | |
" \n", | |
" return response.json()[\"error\"]\n", | |
" \n", | |
" \n", | |
"def read_file(f: str):\n", | |
" if f is None:\n", | |
" return \"\"\n", | |
" elif f.name.endswith('pdf'):\n", | |
" loader = UnstructuredPDFLoader(f.name)\n", | |
" pages = loader.load_and_split()\n", | |
" return \"\\n\".join([p.page_content for p in pages])\n", | |
" elif f.name.endswith('tex'):\n", | |
" with open(f.name, \"r\") as f:\n", | |
" return f.read()\n", | |
" else:\n", | |
" return \"Only support .tex & .pdf\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "cec63e87-9741-4596-a3f1-a901830e3771", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/gradio/components/button.py:112: UserWarning: The `style` method is deprecated. Please set these arguments in the constructor instead.\n", | |
" warnings.warn(\n", | |
"/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/gradio/layouts.py:80: UserWarning: The `style` method is deprecated. Please set these arguments in the constructor instead.\n", | |
" warnings.warn(\n", | |
"/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/gradio/components/textbox.py:259: UserWarning: The `style` method is deprecated. Please set these arguments in the constructor instead.\n", | |
" warnings.warn(\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Running on local URL: http://0.0.0.0:7653\n", | |
"Running on public URL: https://9d8a844ccd23f16d95.gradio.live\n", | |
"\n", | |
"This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<div><iframe src=\"https://9d8a844ccd23f16d95.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"idea_list = []\n", | |
"max_ideas = 20\n", | |
"\n", | |
"\n", | |
"with gr.Blocks() as demo:\n", | |
" \n", | |
" def generate(txt: str, openai_key: str, progress=gr.Progress()):\n", | |
" \n", | |
" if not openai_key:\n", | |
" raise gr.Error(\"Please provide openai key !\")\n", | |
" \n", | |
" try:\n", | |
" global idea_list\n", | |
" idea_list = analyze(txt, openai_key, progress)\n", | |
" k = min(len(idea_list), max_ideas)\n", | |
"\n", | |
" idea_buttons = [\n", | |
" gr.Button.update(visible=True, value=i['title'])\n", | |
" for e, i in enumerate(idea_list[:max_ideas])\n", | |
" ]\n", | |
" idea_buttons += [\n", | |
" gr.Button.update(visible=False)\n", | |
" ]*(max_ideas-len(idea_buttons))\n", | |
"\n", | |
" idea_details = [\n", | |
" gr.Textbox.update(value=\"\", label=\"thought\", visible=True),\n", | |
" gr.Textbox.update(value=\"\", label=\"action\", visible=True),\n", | |
" gr.Textbox.update(value=\"\", label=\"original\", visible=True, max_lines=5, lines=5),\n", | |
" gr.Textbox.update(value=\"\", label=\"improved\", visible=True, max_lines=5, lines=5)\n", | |
" ]\n", | |
"\n", | |
" return [\n", | |
" gr.Textbox.update(label=\"openai_key\"),\n", | |
" gr.Textbox.update(\"Suggestions\", interactive=False, show_label=False),\n", | |
" gr.Button.update(visible=True, value=\"Analyze\")\n", | |
" ] + idea_details + idea_buttons\n", | |
" except Exception as e:\n", | |
" raise gr.Error(str(e))\n", | |
"\n", | |
" def select(name: str):\n", | |
" global idea_list\n", | |
" for i in idea_list:\n", | |
" if i['title'] == name:\n", | |
" return [\n", | |
" gr.Textbox.update(value=i[\"thought\"], label=\"thought\", visible=True),\n", | |
" gr.Textbox.update(value=i[\"action\"], label=\"action\", visible=True),\n", | |
" gr.Textbox.update(value=i[\"original\"], label=\"original\", visible=True, max_lines=5, lines=5),\n", | |
" gr.Textbox.update(value=i[\"improved\"], label=\"improved\", visible=True, max_lines=5, lines=5)\n", | |
" ]\n", | |
" \n", | |
" title = gr.Button(\"PaperGPT\", interactive=False).style(size=10)\n", | |
" key = gr.Textbox(label=\"openai_key\")\n", | |
" with gr.Row().style(equal_height=True):\n", | |
" with gr.Column(scale=0.95):\n", | |
" txt_in = gr.Textbox(label=\"Input\", lines=11, max_lines=11, value=content[2048+2048+256-45:])\n", | |
" with gr.Column(scale=0.05):\n", | |
" upload = gr.File(file_count=\"single\", file_types=[\"tex\", \".pdf\"])\n", | |
" btn = gr.Button(\"Analyze\")\n", | |
" upload.change(read_file, inputs=upload, outputs=txt_in)\n", | |
"\n", | |
" textboxes = []\n", | |
" sug = gr.Textbox(\"Suggestions\", interactive=False, show_label=False).style(text_align=\"center\")\n", | |
" with gr.Row():\n", | |
" with gr.Column(scale=0.4):\n", | |
" for i in range(max_ideas):\n", | |
" t = gr.Button(\"\", visible=False)\n", | |
" textboxes.append(t)\n", | |
" with gr.Column(scale=0.6):\n", | |
" thought = gr.Textbox(label=\"thought\", visible=False, interactive=False)\n", | |
" action = gr.Textbox(label=\"action\", visible=False, interactive=False)\n", | |
" original = gr.Textbox(label=\"original\", visible=False, max_lines=5, lines=5, interactive=False)\n", | |
" improved = gr.Textbox(label=\"improved\", visible=False, max_lines=5, lines=5, interactive=False)\n", | |
"\n", | |
" btn.click(generate, inputs=[txt_in, key], outputs=[key, sug, btn, thought, action, original, improved] + textboxes)\n", | |
" for i in textboxes:\n", | |
" i.click(select, inputs=[i], outputs=[thought, action, original, improved])\n", | |
" demo.launch(server_name=\"0.0.0.0\", server_port=7653, share=True, enable_queue=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"id": "8ac8aa92-f7a6-480c-a1b9-2f1c61426846", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Closing server running on port: 7653\n" | |
] | |
} | |
], | |
"source": [ | |
"demo.close()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "c9a19815-b8de-4a99-9fcf-0b1a0d3981a3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.16" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment