# File size: 2,350 Bytes
# ed3696e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import html

import streamlit as st
from diff_match_patch import diff_match_patch
from transformers import pipeline

# Load grammar correction pipeline
@st.cache_resource
def load_grammar_model():
    """Build the T5 text2text pipeline that produces grammar-corrected text.

    Callers prefix their input with ``grammar: `` before invoking it. The
    ``st.cache_resource`` decorator keeps the constructed pipeline cached so
    Streamlit reruns reuse it instead of building it again.
    """
    task_name = "text2text-generation"
    checkpoint = "vennify/t5-base-grammar-correction"
    return pipeline(task_name, model=checkpoint)

# Load the explanation/suggestion model (google/flan-t5-large)
@st.cache_resource
def load_explainer_model():
    """Build the flan-t5-large text2text pipeline used for explanations and tips.

    The pipeline is cached by ``st.cache_resource``, so reruns of the script
    share one instance rather than reconstructing the model each time.
    """
    checkpoint = "google/flan-t5-large"
    task_name = "text2text-generation"
    return pipeline(task_name, model=checkpoint)

grammar_model = load_grammar_model()
explainer_model = load_explainer_model()
dmp = diff_match_patch()

# Streamlit re-executes this whole script on every widget interaction, and a
# button reads True only during the single rerun triggered by its own click.
# The correction is therefore kept in session state so the follow-up buttons
# ("Explain Corrections", "Suggest Improvements") still have it on later reruns.
_RESULT_KEY = "grammar_result"

# Inline styles for the diff view, keyed by diff_match_patch operation code.
_DIFF_STYLES = {
    diff_match_patch.DIFF_DELETE: "background-color:#fbb;",
    diff_match_patch.DIFF_INSERT: "background-color:#bfb;",
}


def _diff_to_html(diffs):
    """Render ``(op, text)`` diff tuples as HTML: deletions red, insertions green.

    Every fragment is HTML-escaped because it comes from user input or model
    output and is displayed with ``unsafe_allow_html=True``. Newlines become
    ``<br>`` so a blank line in the text cannot end the enclosing HTML block
    and have the remainder re-parsed as Markdown.
    """
    parts = []
    for op, data in diffs:
        text = html.escape(data).replace("\n", "<br>")
        style = _DIFF_STYLES.get(op)
        parts.append(f'<span style="{style}">{text}</span>' if style else text)
    return "".join(parts)


st.title("Grammarly-like AI Writing Assistant")
st.markdown("Fix grammar, punctuation, spelling, tenses — with explanations and tips!")

# User input
user_input = st.text_area("Enter your sentence, paragraph, or essay:", height=200)

if st.button("Correct Grammar"):
    if user_input.strip():
        output = grammar_model(
            f"grammar: {user_input}", max_length=512, do_sample=False
        )[0]["generated_text"]
        st.session_state[_RESULT_KEY] = {"source": user_input, "output": output}
    else:
        st.session_state.pop(_RESULT_KEY, None)
        st.warning("Please enter some text to correct.")

result = st.session_state.get(_RESULT_KEY)
# Show a stored correction only while it still matches the text in the box;
# once the user edits the input, the old correction is stale.
if result is not None and result["source"] == user_input:
    output = result["output"]
    st.subheader("Corrected Text")
    st.success(output)

    # Character-level diff; diff_cleanupSemantic merges tiny edits into
    # human-readable chunks.
    st.subheader("Changes Highlighted")
    diffs = dmp.diff_main(user_input, output)
    dmp.diff_cleanupSemantic(diffs)
    st.markdown(
        f"<div style='font-family:monospace;'>{_diff_to_html(diffs)}</div>",
        unsafe_allow_html=True,
    )

    # Explanation: give the model both versions so it can describe the changes.
    if st.button("Explain Corrections"):
        explanation_prompt = (
            "Explain the grammar issues in this text and how it was improved.\n"
            f"Original: {user_input}\n"
            f"Corrected: {output}"
        )
        result["explanation"] = explainer_model(
            explanation_prompt, max_length=200
        )[0]["generated_text"]
    if "explanation" in result:
        st.subheader("Explanation")
        st.info(result["explanation"])

    # Suggest improvements
    if st.button("Suggest Improvements"):
        suggest_prompt = f"Suggest improvements to make this writing more professional: {output}"
        result["suggestions"] = explainer_model(
            suggest_prompt, max_length=200
        )[0]["generated_text"]
    if "suggestions" in result:
        st.subheader("Suggestions")
        st.warning(result["suggestions"])