Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline | |
import torch | |
# Load grammar correction model (cached across Streamlit reruns)
@st.cache_resource
def load_grammar_model():
    """Load the grammar-correction tokenizer and seq2seq model.

    Streamlit re-executes this script from the top on every widget
    interaction. Without ``st.cache_resource`` the checkpoint would be
    re-loaded on every button click. The decorator keeps a single shared copy
    that is reused across reruns and sessions.

    Returns:
        tuple: ``(tokenizer, model)`` for the
        ``prithivida/grammar_error_correcter_v1`` checkpoint.
    """
    model_name = "prithivida/grammar_error_correcter_v1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model
# Load explanation model (cached across Streamlit reruns)
@st.cache_resource
def load_explainer():
    """Build the text2text-generation pipeline used to explain corrections.

    Cached with ``st.cache_resource`` so the ``google/flan-t5-base`` weights
    are loaded once and reused. Without the cache they would be loaded again
    on every Streamlit rerun of this script.

    Returns:
        A ``transformers`` text2text-generation pipeline. ``max_length=256``
        is passed through to the pipeline, presumably as the cap on generated
        length.
    """
    return pipeline("text2text-generation", model="google/flan-t5-base", max_length=256)
# Create the module-level model handles read by correct_grammar and
# explain_correction. The grammar model is loaded first, then the explainer.
(grammar_tokenizer, grammar_model), explanation_model = (
    load_grammar_model(),
    load_explainer(),
)
# Grammar correction function
def correct_grammar(text):
    """Return the grammar-correction model's rewrite of *text*.

    The text is prefixed with the ``gec: `` task tag and tokenised with
    truncation enabled. Output is generated with 4-beam search capped at
    512 tokens, and the first returned sequence is decoded without special
    tokens.
    """
    prompt = "gec: " + text
    token_ids = grammar_tokenizer.encode(prompt, return_tensors="pt", truncation=True)
    sequences = grammar_model.generate(
        token_ids,
        max_length=512,
        num_beams=4,
        early_stopping=True,
    )
    best_sequence = sequences[0]
    return grammar_tokenizer.decode(best_sequence, skip_special_tokens=True)
# Explanation function using second model
def explain_correction(original, corrected):
    """Ask the explanation model why *original* was rewritten as *corrected*.

    If the correction model returned the text unchanged, there is nothing to
    explain. Differences in whitespace alone are ignored, since detokenization
    can collapse or strip spaces. Prompting the model to "explain the
    improvements" in that case would likely produce an explanation of edits
    that never happened. So a fixed message is returned and the model call is
    skipped.

    Args:
        original: Text as entered by the user.
        corrected: Output of the grammar-correction model.

    Returns:
        str: Explanation text to display.
    """
    if " ".join(original.split()) == " ".join(corrected.split()):
        return "No grammar changes were needed, so there is nothing to explain."
    prompt = f"Explain the grammar improvements made from: \"{original}\" to: \"{corrected}\""
    # The pipeline returns a list with one dict per generated sequence.
    return explanation_model(prompt)[0]["generated_text"]
# Streamlit App UI
# NOTE: the heading emojis were mis-decoded in the source ("π", "β"); they are
# restored here with emojis whose UTF-8 lead byte matches the garbled
# characters.
st.title("📝 Smart Grammar Correction with Explanations")
st.write("Enter your text below. The AI will correct grammar **and explain why** the changes were made using grammar principles.")
user_input = st.text_area("Your Text", height=200, placeholder="Type or paste your text here...")
if st.button("Correct and Explain"):
    # Whitespace-only input is rejected before any model is invoked.
    if user_input.strip():
        with st.spinner("Correcting grammar..."):
            corrected = correct_grammar(user_input)
        with st.spinner("Explaining corrections..."):
            explanation = explain_correction(user_input, corrected)
        st.subheader("✅ Corrected Text")
        st.success(corrected)
        st.subheader("📘 Explanation (Why it was changed)")
        st.markdown(explanation)
    else:
        st.warning("Please enter some text to correct.")