# Grammarly-like AI Writing Assistant — Streamlit app
# (grammar correction with highlighted diffs, explanations, and suggestions)
import streamlit as st
from transformers import pipeline
from diff_match_patch import diff_match_patch
# Load grammar correction pipeline
@st.cache_resource
def load_grammar_model():
    """Load and cache the T5 grammar-correction pipeline.

    Returns:
        A transformers text2text-generation pipeline for
        ``vennify/t5-base-grammar-correction``.

    The ``@st.cache_resource`` decorator ensures the model is loaded
    once per server process instead of on every Streamlit rerun.
    """
    return pipeline("text2text-generation", model="vennify/t5-base-grammar-correction")
# Optional: load explanation model (like flan-t5)
@st.cache_resource
def load_explainer_model():
    """Load and cache the FLAN-T5 pipeline used for explanations/suggestions.

    Returns:
        A transformers text2text-generation pipeline for
        ``google/flan-t5-large``.

    Cached with ``@st.cache_resource`` so the (large) model is loaded
    only once per server process, not on every rerun.
    """
    return pipeline("text2text-generation", model="google/flan-t5-large")
grammar_model = load_grammar_model()
explainer_model = load_explainer_model()
dmp = diff_match_patch()

st.title("Grammarly-like AI Writing Assistant")
st.markdown("Fix grammar, punctuation, spelling, tenses — with explanations and tips!")

# User input
user_input = st.text_area("Enter your sentence, paragraph, or essay:", height=200)

if st.button("Correct Grammar"):
    if user_input.strip():
        # Correct the input and persist both texts in session_state so the
        # follow-up buttons still have them after Streamlit reruns the script.
        output = grammar_model(f"grammar: {user_input}", max_length=512, do_sample=False)[0]["generated_text"]
        st.session_state["original_input"] = user_input
        st.session_state["corrected_output"] = output

# Render results from session_state rather than inside the button branch:
# st.button returns True for a single rerun only, so nesting the
# Explain/Suggest buttons inside the "Correct Grammar" branch would make
# them unreachable (clicking them reruns with the outer button False).
if "corrected_output" in st.session_state:
    original = st.session_state["original_input"]
    output = st.session_state["corrected_output"]

    st.subheader("Corrected Text")
    st.success(output)

    # Show word-by-word diff
    st.subheader("Changes Highlighted")
    diffs = dmp.diff_main(original, output)
    dmp.diff_cleanupSemantic(diffs)
    html_diff = ""
    for op, data in diffs:
        if op == -1:
            # Deleted text: red background
            html_diff += f'<span style="background-color:#fbb;">{data}</span>'
        elif op == 1:
            # Inserted text: green background
            html_diff += f'<span style="background-color:#bfb;">{data}</span>'
        else:
            html_diff += data
    st.markdown(f"<div style='font-family:monospace;'>{html_diff}</div>", unsafe_allow_html=True)

    # Explanation
    if st.button("Explain Corrections"):
        explanation_prompt = f"Explain the grammar issues in this text and how it was improved: {original}"
        explanation = explainer_model(explanation_prompt, max_length=200)[0]['generated_text']
        st.subheader("Explanation")
        st.info(explanation)

    # Suggest improvements
    if st.button("Suggest Improvements"):
        suggest_prompt = f"Suggest improvements to make this writing more professional: {output}"
        suggestions = explainer_model(suggest_prompt, max_length=200)[0]['generated_text']
        st.subheader("Suggestions")
        st.warning(suggestions)