This update significantly enhances the "Gramformer Demo" application by modernizing its architecture, improving performance, and refining the user interface. It transitions the application from a class-based structure to a more functional, modular design that is idiomatic for Streamlit, improving readability and maintainability.
The most critical improvement is the migration from the legacy `@st.cache` decorator to the current `@st.cache_resource` for loading the Gramformer model. This ensures the large model is loaded into memory only once and reused efficiently across sessions, giving a much faster experience after the initial startup.
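For context, a minimal before/after sketch of the caching change (`load_model` here is a hypothetical stand-in for the real loader shown in the full code below):

```python
import streamlit as st

# Legacy pattern: @st.cache required allow_output_mutation=True for
# unhashable objects such as models, and is deprecated in current Streamlit.
#
# @st.cache(allow_output_mutation=True)
# def load_model(): ...

# Current pattern: st.cache_resource is meant for global resources
# (models, database connections) and keeps one shared instance across
# sessions and reruns instead of re-creating it on every script run.
@st.cache_resource
def load_model() -> dict:
    # Hypothetical stand-in for the expensive Gramformer initialization.
    return {"name": "demo-model"}
```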
The UI has been redesigned into a clean, two-column layout that lets users compare the original text and the corrected results side by side. The code also automatically detects and uses a CUDA-enabled GPU when one is available, gracefully falling back to the CPU otherwise.
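For orientation before the full walkthrough, here is a minimal standalone sketch of the Gramformer calls the app relies on; the constructor and `correct()` usage mirror the full code below, though the exact candidate text can vary between runs:

```python
from gramformer import Gramformer

# models=1 selects the corrector model, matching load_gramformer() below.
gf = Gramformer(models=1, use_gpu=False)

# correct() returns a set of candidate corrections for the input sentence.
for candidate in gf.correct("He are moving here.", max_candidates=1):
    print(candidate)  # typically something like "He is moving here."
```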
```mermaid
sequenceDiagram
    actor User
    participant App as "Streamlit App"
    participant DeviceDetector as "Device Logic"
    participant Cache
    participant GramformerModel as "Gramformer Model"

    User->>App: Starts App
    App->>DeviceDetector: get_compute_device()
    DeviceDetector-->>App: Returns 'cuda' or 'cpu'
    App->>Cache: load_gramformer(device)
    alt Model not in Cache
        Cache->>GramformerModel: Initialize(device)
        GramformerModel-->>Cache: Model Ready
    end
    Cache-->>App: Returns Cached Model
    User->>App: Enters text & clicks "Correct"
    App->>GramformerModel: correct(text)
    GramformerModel-->>App: Returns Correction & Edits
    App-->>User: Displays Input, Output & Analysis
```
Diagram illustrating the streamlined startup process. The app first detects the optimal hardware, then loads the model via a cache, ensuring efficient resource use and fast performance for the user.
```python
import re

import streamlit as st
import pandas as pd
import torch
from gramformer import Gramformer
from annotated_text import annotated_text
from bs4 import BeautifulSoup


# --- Core Logic & Model Handling ---

def get_compute_device() -> tuple[str, bool]:
    """
    Detects and returns the best available compute device and a GPU flag.
    Checks for CUDA, falls back to CPU.
    Note: Gramformer's `use_gpu` is a boolean for CUDA only. DirectML is not
    natively supported by the library's interface.
    """
    if torch.cuda.is_available():
        return "cuda", True
    # To add DirectML support in the future if Gramformer allows device passing:
    # try:
    #     import torch_directml
    #     if torch_directml.is_available():
    #         return "privateuseone", True  # Would require Gramformer modification
    # except ImportError:
    #     pass
    return "cpu", False


@st.cache_resource(show_spinner="Loading grammar model...")
def load_gramformer(model_type: int = 1, use_gpu: bool = False) -> Gramformer:
    """
    Loads and caches the Gramformer model to prevent reloading on each run.
    """
    # Setting seed for reproducibility, consistent with the original script
    torch.manual_seed(1212)
    if use_gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(1212)
    # models=1 is the corrector model
    return Gramformer(models=model_type, use_gpu=use_gpu)


def correct_text(gf: Gramformer, text: str, max_candidates: int) -> str:
    """
    Uses the loaded Gramformer model to correct the input text.
    Returns the top candidate, or the original text if no correction is found.
    """
    if not text.strip():
        return ""
    # gf.correct returns a set of candidate sentences; convert it to a list
    # so the top candidate can be indexed.
    corrected_sentences = list(gf.correct(text, max_candidates=max_candidates))
    if not corrected_sentences:
        return text
    return corrected_sentences[0]


# --- UI Rendering Components ---

def render_highlights(gf: Gramformer, original_text: str, corrected_text: str):
    """
    Parses Gramformer's highlight format and renders it using annotated_text.
    """
    try:
        highlight_html = gf.highlight(original_text, corrected_text)
        # Pad every closing tag with a space in a single pass so the split
        # below keeps tag contents separated from the surrounding words.
        highlight_html = re.sub(r'</([dac])>', r' </\1>', highlight_html)
        tokens = re.split(r'(<[dac]\s.*?</[dac]>)', highlight_html)
        annotations = []
        color_map = {'d': '#faa', 'a': '#afa', 'c': '#fea'}  # delete, add, correct
        for token in filter(None, tokens):
            soup = BeautifulSoup(token, 'html.parser')
            tag = soup.find(['d', 'a', 'c'])
            if tag:
                tag_name = tag.name
                edit_type = tag.get('type', '')
                edit_text = tag.get('edit', tag.text)
                color = color_map.get(tag_name, '#fff')
                if tag_name == 'd':
                    # Apply a combining strikethrough to each deleted character
                    edit_text = '\u0336'.join(tag.text) + '\u0336'
                annotations.append((edit_text, edit_type, color))
            else:
                annotations.append(token)
        if annotations:
            # annotated_text takes only positional arguments; it exposes no
            # height or scrolling parameters.
            annotated_text(*annotations)
        else:
            st.info("No highlightable edits were found.")
    except Exception as e:
        st.error(f"An error occurred while generating highlights: {e}")


def render_edits_table(gf: Gramformer, original_text: str, corrected_text: str):
    """
    Renders the grammatical edits in a pandas DataFrame.
    """
    try:
        edits = gf.get_edits(original_text, corrected_text)
        if edits:
            # get_edits yields 7-tuples: (type, original token, original start,
            # original end, corrected token, corrected start, corrected end).
            df = pd.DataFrame(edits, columns=[
                'Type', 'Original Token', 'Orig Start', 'Orig End',
                'Correct Token', 'Corr Start', 'Corr End',
            ])
            df = df.set_index('Type')
            st.table(df)
        else:
            st.info("No grammatical edits were found.")
    except Exception as e:
        st.error(f"An error occurred while generating the edits table: {e}")


# --- Main Application ---

def main():
    """
    The main function that runs the Streamlit application.
    """
    st.set_page_config(
        page_title="Advanced Gramformer",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    # --- Sidebar Configuration ---
    st.sidebar.title("Configuration")
    model_type_key = st.sidebar.selectbox(
        "Choose Model",
        ['Corrector', 'Detector - coming soon'],
        help="Select the grammar model to use. Only 'Corrector' is currently available."
    )
    max_candidates = 1
    if model_type_key == 'Corrector':
        max_candidates = st.sidebar.number_input(
            "Max Correction Candidates",
            min_value=1, max_value=10, value=1,
            help="The number of correction suggestions to generate."
        )
    else:
        st.sidebar.warning("Detector model is not yet implemented.")

    # --- Load Model ---
    _, use_gpu = get_compute_device()
    gf = load_gramformer(use_gpu=use_gpu)

    # --- Main Page Layout ---
    st.title("Advanced Gramformer Demo")
    st.markdown("A framework for detecting, highlighting, and correcting grammatical errors.")

    examples = [
        "what be the reason for everyone leave the comapny", "He are moving here.",
        "I am doing fine. How is you?", "How is they?", "Matt like fish",
        "the collection of letters was original used by the ancient Romans",
        "We enjoys horror movies", "Anna and Mike is going skiing",
        "I walk to the store and I bought milk", "We all eat the fish and then made dessert",
    ]

    col1, col2 = st.columns([1, 1], gap="medium")
    with col1:
        st.subheader("Original Text")
        example_selection = st.selectbox("Choose an example or write your own:", examples)
        input_text = st.text_area(
            "Enter text here:",
            value=example_selection,
            height=300,
            label_visibility="collapsed"
        )
    with col2:
        st.subheader("Analysis & Correction")
        if st.button("Correct Text", use_container_width=True, type="primary"):
            if model_type_key != 'Corrector' or not input_text.strip():
                st.warning("Please select the 'Corrector' model and enter some text.")
            else:
                with st.spinner("Analyzing..."):
                    corrected_sentence = correct_text(gf, input_text, max_candidates)
                st.markdown("##### Corrected Sentence:")
                st.success(corrected_sentence)
                with st.expander("Show Highlights", expanded=True):
                    render_highlights(gf, input_text, corrected_sentence)
                with st.expander("Show Edits Table"):
                    render_edits_table(gf, input_text, corrected_sentence)


if __name__ == "__main__":
    main()
```
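With the dependencies below saved to requirements.txt and installed (`pip install -r requirements.txt`), the app launches the standard Streamlit way, e.g. `streamlit run app.py` (assuming the script is saved as the app.py referenced in the summary table).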
```text
streamlit
pandas
beautifulsoup4
st-annotated-text
torch
# Gramformer is typically installed from its GitHub repository:
gramformer @ git+https://github.com/PrithivirajDamodaran/Gramformer.git
# For AMD/Intel GPU support on Windows, uncomment and install the line below:
# torch-directml
```
| Category | Description | Affected Files/Areas |
| --- | --- | --- |
| Architecture & Core | Rewrote the class-based structure as a functional, modular design idiomatic for Streamlit. | app.py |
| Performance | Replaced the legacy `@st.cache` with `@st.cache_resource` so the model loads once and is reused across sessions. | app.py |
| UI/UX | Added a two-column layout for side-by-side comparison, with expanders for highlights and an edits table. | app.py |
| Code Quality | Added type hints, docstrings, and error handling around the highlight and edits rendering. | app.py |
| Dependencies | Listed all required packages, with an optional `torch-directml` note for AMD/Intel GPUs on Windows. | requirements.txt |
| Setup & Usage | Unchanged; the app is launched the standard Streamlit way. | - |