cpv2280 commited on
Commit
845e915
·
verified ·
1 Parent(s): 6a6b31d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ import gradio as gr
5
+ import ftfy
6
+ import language_tool_python
7
+ import re
8
+ from sentence_transformers import SentenceTransformer, util
9
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
10
+
11
+ # Load fine-tuned GPT-2 model
12
+ model_path = "/content/drive/MyDrive/gpt2_tinystories_finetuned" # Update if needed
13
+ model = AutoModelForCausalLM.from_pretrained(model_path)
14
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
15
+
16
+ # Create a text-generation pipeline
17
+ story_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
18
+
19
+ # Load NLP tools
20
+ tool = language_tool_python.LanguageTool('en-UK')
21
+ sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
22
+
23
+ def refine_story(text):
24
+ """Refines the generated story by fixing encoding, grammar, and redundancy."""
25
+ text = ftfy.fix_text(text) # Fix encoding
26
+ matches = tool.check(text) # Check grammar
27
+ text = language_tool_python.utils.correct(text, matches) # Apply fixes
28
+
29
+ # Remove redundant words/phrases
30
+ text = re.sub(r'(\b\w+\b) \1', r'\1', text) # Remove duplicate words
31
+ text = re.sub(r'(\b\w+ and \w+\b)(,? \1)+', r'\1', text) # Remove phrase repetitions
32
+
33
+ return text
34
+
35
+ def detect_inconsistencies(text):
36
+ """Checks for logical inconsistencies by comparing sentence similarities."""
37
+ sentences = text.split(". ")
38
+ inconsistencies = []
39
+
40
+ # Compare each sentence with the next one
41
+ for i in range(len(sentences) - 1):
42
+ similarity_score = util.pytorch_cos_sim(sentence_model.encode(sentences[i]), sentence_model.encode(sentences[i+1]))
43
+
44
+ if similarity_score.item() < 0.3: # If similarity is low, flag as inconsistent
45
+ inconsistencies.append(f"⚠️ **Possible inconsistency detected:**\n➡ {sentences[i]} \n➡ {sentences[i+1]}")
46
+
47
+ return "\n\n".join(inconsistencies) if inconsistencies else "✅ No major inconsistencies detected."
48
+
49
+ def story_pipeline(prompt):
50
+ """Generates a story, refines it, and checks inconsistencies."""
51
+ # Generate the story
52
+ generated = story_generator(prompt, max_length=200, do_sample=True, temperature=1.0, top_p=0.9, top_k=50)
53
+ raw_story = generated[0]['generated_text']
54
+
55
+ # Refine the generated story
56
+ refined_story = refine_story(raw_story)
57
+
58
+ # Detect logical inconsistencies
59
+ inconsistencies = detect_inconsistencies(refined_story)
60
+
61
+ return raw_story, refined_story, inconsistencies
62
+
63
+ # ✅ Gradio Interface with Proper Logical Inconsistency Detection
64
+ interface = gr.Interface(
65
+ fn=story_pipeline,
66
+ inputs=gr.Textbox(label="Enter Story Prompt", placeholder="Once upon a time..."),
67
+ outputs=[
68
+ gr.Textbox(label="📖 Generated Story", interactive=True), # Interactive textbox
69
+ gr.Textbox(label="✅ Refined Story", interactive=True), # Refined output
70
+ gr.Textbox(label="⚠️ Logical Inconsistencies", interactive=False), # Shows inconsistencies correctly
71
+ ],
72
+ title="📖 FableWeaver AI",
73
+ description="Generates AI-powered TinyStories using GPT-2 fine-tuned on TinyStories. Automatically refines the story and detects logical inconsistencies."
74
+ )
75
+
76
+ # Launch Gradio app
77
+ interface.launch(share="True")