Skip to content

Instantly share code, notes, and snippets.

@j40903272
Created June 18, 2023 02:12
Show Gist options
  • Save j40903272/fb5ab7b3e39fd4adabe1c00962b1df71 to your computer and use it in GitHub Desktop.
Save j40903272/fb5ab7b3e39fd4adabe1c00962b1df71 to your computer and use it in GitHub Desktop.
research paper editor
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "raw",
"id": "b77c4818-9d92-4581-a2fb-2b38da776a8b",
"metadata": {},
"source": [
"# python >= 3.8\n",
"import sys\n",
"!{sys.executable} -m pip install langchain gradio tiktoken unstructured"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "80b2033d-b985-440d-a01f-e7f8547a6801",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/requests/__init__.py:102: RequestsDependencyWarning: urllib3 (1.26.8) or chardet (5.1.0)/charset_normalizer (2.0.12) doesn't match a supported version!\n",
" warnings.warn(\"urllib3 ({}) or chardet ({})/charset_normalizer ({}) doesn't match a supported \"\n",
"/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"# from langchain.text_splitter import LatexTextSplitter\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from typing import Any\n",
"import requests\n",
"import logging\n",
"import json\n",
"import tiktoken\n",
"import gradio as gr\n",
"from langchain.document_loaders import UnstructuredPDFLoader"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "720927dd-848e-47ff-a847-4e1097768561",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"turbo_encoding = tiktoken.encoding_for_model(\"gpt-3.5-turbo\")\n",
"with open(\"sample.tex\", \"r\") as f:\n",
" content = f.read()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "13bc48fd-a048-4635-99c4-79ddb62d5c4d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"class LatexTextSplitter(RecursiveCharacterTextSplitter):\n",
" \"\"\"Attempts to split the text along Latex-formatted layout elements.\"\"\"\n",
"\n",
" def __init__(self, **kwargs: Any):\n",
" \"\"\"Initialize a LatexTextSplitter.\"\"\"\n",
" separators = [\n",
" # First, try to split along Latex sections\n",
" \"\\chapter{\",\n",
" \"\\section{\",\n",
" \"\\subsection{\",\n",
" \"\\subsubsection{\",\n",
"\n",
" # Now split by environments\n",
" \"\\begin{\"\n",
" # \"\\n\\\\begin{enumerate}\",\n",
" # \"\\n\\\\begin{itemize}\",\n",
" # \"\\n\\\\begin{description}\",\n",
" # \"\\n\\\\begin{list}\",\n",
" # \"\\n\\\\begin{quote}\",\n",
" # \"\\n\\\\begin{quotation}\",\n",
" # \"\\n\\\\begin{verse}\",\n",
" # \"\\n\\\\begin{verbatim}\",\n",
"\n",
" ## Now split by math environments\n",
" # \"\\n\\\\begin{align}\",\n",
" # \"$$\",\n",
" # \"$\",\n",
"\n",
" # Now split by the normal type of lines\n",
" \" \",\n",
" \"\",\n",
" ]\n",
" super().__init__(separators=separators, **kwargs)\n",
"\n",
"\n",
"def json_validator(text: str, openai_key: str, retry: int = 3):\n",
" for _ in range(retry):\n",
" try:\n",
" return json.loads(text)\n",
" except Exception:\n",
" \n",
" try:\n",
" prompt = f\"Modify the following into a valid json format:\\n{text}\"\n",
" prompt_token_length = len(turbo_encoding.encode(prompt))\n",
"\n",
" data = {\n",
" \"model\": \"text-davinci-003\",\n",
" \"prompt\": prompt,\n",
" \"max_tokens\": 4097 - prompt_token_length - 64\n",
" }\n",
" headers = {\n",
" \"Content-Type\": \"application/json\",\n",
" \"Authorization\": f\"Bearer {openai_key}\"\n",
" }\n",
" for _ in range(retry):\n",
" response = requests.post(\n",
" 'https://api.openai.com/v1/completions',\n",
" json=data,\n",
" headers=headers,\n",
" timeout=300\n",
" )\n",
" if response.status_code != 200:\n",
" logging.warning(f'fetch openai chat retry: {response.text}')\n",
" continue\n",
" text = response.json()['choices'][0]['text']\n",
" break\n",
" except:\n",
" return response.json()['error']\n",
" \n",
" return text"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a34bf526-fb0c-4b5e-8c3d-bab7a13e5fe7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def analyze(latex_whole_document: str, openai_key: str, progress):\n",
" \n",
" logging.info(\"start analysis\")\n",
" \n",
" output_format = \"\"\"\n",
"\n",
" ```json\n",
" [\n",
" \\\\ Potential point for improvement 1\n",
" {{\n",
" \"title\": string \\\\ What this modification is about\n",
" \"thought\": string \\\\ The reason why this should be improved\n",
" \"action\": string \\\\ how to make improvement\n",
" \"original\": string \\\\ the original latex snippet that can be improved\n",
" \"improved\": string \\\\ the improved latex snippet which address your point\n",
" }},\n",
" {{}}\n",
" ]\n",
" ```\n",
" \"\"\"\n",
" \n",
" chunk_size = 1000\n",
" # for _ in range(5):\n",
" # try:\n",
" # latex_splitter = LatexTextSplitter(\n",
" # chunk_size=min(chunk_size, len(latex_whole_document)),\n",
" # chunk_overlap=0,\n",
" # )\n",
" # docs = latex_splitter.create_documents([latex_whole_document])\n",
" # break\n",
" # except:\n",
" # chunk_size // 2\n",
"\n",
" latex_splitter = LatexTextSplitter(\n",
" chunk_size=min(chunk_size, len(latex_whole_document)),\n",
" chunk_overlap=0,\n",
" )\n",
" docs = latex_splitter.create_documents([latex_whole_document])\n",
" \n",
" progress(0.05)\n",
" ideas = []\n",
" for doc in progress.tqdm(docs):\n",
"\n",
" prompt = f\"\"\"\n",
" ```\n",
" {doc.page_content}\n",
" ```\n",
" I'm a computer science student. The above is my research paper.\n",
" You are my editor.\n",
" Your goal is to improve the paper quality at your best.\n",
" Point out the parts that can be improved.\n",
" Focus on grammar, writing, content, section structure.\n",
" Ignore comments and package-inclusion commands.\n",
" List out all the points with a latex snippet which is the improved version addressing your point.\n",
" The same paragraph should only be addressed once.\n",
" Output the response in the following valid json format:\n",
" {output_format}\n",
"\n",
" \"\"\"\n",
" \n",
" idea = fetch_chat(prompt, openai_key)\n",
" if isinstance(idea, list):\n",
" ideas += idea\n",
" break\n",
" else:\n",
" raise gr.Error(idea)\n",
"\n",
" logging.info('complete analysis')\n",
" return ideas\n",
"\n",
"\n",
"def fetch_chat(prompt: str, openai_key: str, retry: int = 3):\n",
" json = {\n",
" \"model\": \"gpt-3.5-turbo-16k\",\n",
" \"messages\": [{\"role\": \"user\", \"content\": prompt}]\n",
" }\n",
" headers = {\n",
" \"Content-Type\": \"application/json\",\n",
" \"Authorization\": f\"Bearer {openai_key}\"\n",
" }\n",
" for _ in range(retry):\n",
" response = requests.post(\n",
" 'https://api.openai.com/v1/chat/completions',\n",
" json=json,\n",
" headers=headers,\n",
" timeout=300\n",
" )\n",
" if response.status_code != 200:\n",
" logging.warning(f'fetch openai chat retry: {response.text}')\n",
" continue\n",
" result = response.json()['choices'][0]['message']['content']\n",
" return json_validator(result, openai_key)\n",
" \n",
" return response.json()[\"error\"]\n",
" \n",
" \n",
"def read_file(f: str):\n",
" if f is None:\n",
" return \"\"\n",
" elif f.name.endswith('pdf'):\n",
" loader = UnstructuredPDFLoader(f.name)\n",
" pages = loader.load_and_split()\n",
" return \"\\n\".join([p.page_content for p in pages])\n",
" elif f.name.endswith('tex'):\n",
" with open(f.name, \"r\") as f:\n",
" return f.read()\n",
" else:\n",
" return \"Only support .tex & .pdf\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "cec63e87-9741-4596-a3f1-a901830e3771",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/gradio/components/button.py:112: UserWarning: The `style` method is deprecated. Please set these arguments in the constructor instead.\n",
" warnings.warn(\n",
"/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/gradio/layouts.py:80: UserWarning: The `style` method is deprecated. Please set these arguments in the constructor instead.\n",
" warnings.warn(\n",
"/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/gradio/components/textbox.py:259: UserWarning: The `style` method is deprecated. Please set these arguments in the constructor instead.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running on local URL: http://0.0.0.0:7653\n",
"Running on public URL: https://9d8a844ccd23f16d95.gradio.live\n",
"\n",
"This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
]
},
{
"data": {
"text/html": [
"<div><iframe src=\"https://9d8a844ccd23f16d95.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"idea_list = []\n",
"max_ideas = 20\n",
"\n",
"\n",
"with gr.Blocks() as demo:\n",
" \n",
" def generate(txt: str, openai_key: str, progress=gr.Progress()):\n",
" \n",
" if not openai_key:\n",
" raise gr.Error(\"Please provide openai key !\")\n",
" \n",
" try:\n",
" global idea_list\n",
" idea_list = analyze(txt, openai_key, progress)\n",
" k = min(len(idea_list), max_ideas)\n",
"\n",
" idea_buttons = [\n",
" gr.Button.update(visible=True, value=i['title'])\n",
" for e, i in enumerate(idea_list[:max_ideas])\n",
" ]\n",
" idea_buttons += [\n",
" gr.Button.update(visible=False)\n",
" ]*(max_ideas-len(idea_buttons))\n",
"\n",
" idea_details = [\n",
" gr.Textbox.update(value=\"\", label=\"thought\", visible=True),\n",
" gr.Textbox.update(value=\"\", label=\"action\", visible=True),\n",
" gr.Textbox.update(value=\"\", label=\"original\", visible=True, max_lines=5, lines=5),\n",
" gr.Textbox.update(value=\"\", label=\"improved\", visible=True, max_lines=5, lines=5)\n",
" ]\n",
"\n",
" return [\n",
" gr.Textbox.update(label=\"openai_key\"),\n",
" gr.Textbox.update(\"Suggestions\", interactive=False, show_label=False),\n",
" gr.Button.update(visible=True, value=\"Analyze\")\n",
" ] + idea_details + idea_buttons\n",
" except Exception as e:\n",
" raise gr.Error(str(e))\n",
"\n",
" def select(name: str):\n",
" global idea_list\n",
" for i in idea_list:\n",
" if i['title'] == name:\n",
" return [\n",
" gr.Textbox.update(value=i[\"thought\"], label=\"thought\", visible=True),\n",
" gr.Textbox.update(value=i[\"action\"], label=\"action\", visible=True),\n",
" gr.Textbox.update(value=i[\"original\"], label=\"original\", visible=True, max_lines=5, lines=5),\n",
" gr.Textbox.update(value=i[\"improved\"], label=\"improved\", visible=True, max_lines=5, lines=5)\n",
" ]\n",
" \n",
" title = gr.Button(\"PaperGPT\", interactive=False).style(size=10)\n",
" key = gr.Textbox(label=\"openai_key\")\n",
" with gr.Row().style(equal_height=True):\n",
" with gr.Column(scale=0.95):\n",
" txt_in = gr.Textbox(label=\"Input\", lines=11, max_lines=11, value=content[2048+2048+256-45:])\n",
" with gr.Column(scale=0.05):\n",
" upload = gr.File(file_count=\"single\", file_types=[\"tex\", \".pdf\"])\n",
" btn = gr.Button(\"Analyze\")\n",
" upload.change(read_file, inputs=upload, outputs=txt_in)\n",
"\n",
" textboxes = []\n",
" sug = gr.Textbox(\"Suggestions\", interactive=False, show_label=False).style(text_align=\"center\")\n",
" with gr.Row():\n",
" with gr.Column(scale=0.4):\n",
" for i in range(max_ideas):\n",
" t = gr.Button(\"\", visible=False)\n",
" textboxes.append(t)\n",
" with gr.Column(scale=0.6):\n",
" thought = gr.Textbox(label=\"thought\", visible=False, interactive=False)\n",
" action = gr.Textbox(label=\"action\", visible=False, interactive=False)\n",
" original = gr.Textbox(label=\"original\", visible=False, max_lines=5, lines=5, interactive=False)\n",
" improved = gr.Textbox(label=\"improved\", visible=False, max_lines=5, lines=5, interactive=False)\n",
"\n",
" btn.click(generate, inputs=[txt_in, key], outputs=[key, sug, btn, thought, action, original, improved] + textboxes)\n",
" for i in textboxes:\n",
" i.click(select, inputs=[i], outputs=[thought, action, original, improved])\n",
" demo.launch(server_name=\"0.0.0.0\", server_port=7653, share=True, enable_queue=True)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "8ac8aa92-f7a6-480c-a1b9-2f1c61426846",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Closing server running on port: 7653\n"
]
}
],
"source": [
"demo.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9a19815-b8de-4a99-9fcf-0b1a0d3981a3",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment