cpv2280 committed on
Commit
9b4424e
Β·
verified Β·
1 Parent(s): 845e915

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -77
app.py CHANGED
@@ -1,77 +1,77 @@
1
-
2
-
3
-
4
- import gradio as gr
5
- import ftfy
6
- import language_tool_python
7
- import re
8
- from sentence_transformers import SentenceTransformer, util
9
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
10
-
11
- # Load fine-tuned GPT-2 model
12
- model_path = "/content/drive/MyDrive/gpt2_tinystories_finetuned" # Update if needed
13
- model = AutoModelForCausalLM.from_pretrained(model_path)
14
- tokenizer = AutoTokenizer.from_pretrained(model_path)
15
-
16
- # Create a text-generation pipeline
17
- story_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
18
-
19
- # Load NLP tools
20
- tool = language_tool_python.LanguageTool('en-UK')
21
- sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
22
-
23
- def refine_story(text):
24
- """Refines the generated story by fixing encoding, grammar, and redundancy."""
25
- text = ftfy.fix_text(text) # Fix encoding
26
- matches = tool.check(text) # Check grammar
27
- text = language_tool_python.utils.correct(text, matches) # Apply fixes
28
-
29
- # Remove redundant words/phrases
30
- text = re.sub(r'(\b\w+\b) \1', r'\1', text) # Remove duplicate words
31
- text = re.sub(r'(\b\w+ and \w+\b)(,? \1)+', r'\1', text) # Remove phrase repetitions
32
-
33
- return text
34
-
35
- def detect_inconsistencies(text):
36
- """Checks for logical inconsistencies by comparing sentence similarities."""
37
- sentences = text.split(". ")
38
- inconsistencies = []
39
-
40
- # Compare each sentence with the next one
41
- for i in range(len(sentences) - 1):
42
- similarity_score = util.pytorch_cos_sim(sentence_model.encode(sentences[i]), sentence_model.encode(sentences[i+1]))
43
-
44
- if similarity_score.item() < 0.3: # If similarity is low, flag as inconsistent
45
- inconsistencies.append(f"⚠️ **Possible inconsistency detected:**\n➑ {sentences[i]} \n➑ {sentences[i+1]}")
46
-
47
- return "\n\n".join(inconsistencies) if inconsistencies else "βœ… No major inconsistencies detected."
48
-
49
- def story_pipeline(prompt):
50
- """Generates a story, refines it, and checks inconsistencies."""
51
- # Generate the story
52
- generated = story_generator(prompt, max_length=200, do_sample=True, temperature=1.0, top_p=0.9, top_k=50)
53
- raw_story = generated[0]['generated_text']
54
-
55
- # Refine the generated story
56
- refined_story = refine_story(raw_story)
57
-
58
- # Detect logical inconsistencies
59
- inconsistencies = detect_inconsistencies(refined_story)
60
-
61
- return raw_story, refined_story, inconsistencies
62
-
63
- # βœ… Gradio Interface with Proper Logical Inconsistency Detection
64
- interface = gr.Interface(
65
- fn=story_pipeline,
66
- inputs=gr.Textbox(label="Enter Story Prompt", placeholder="Once upon a time..."),
67
- outputs=[
68
- gr.Textbox(label="πŸ“– Generated Story", interactive=True), # Interactive textbox
69
- gr.Textbox(label="βœ… Refined Story", interactive=True), # Refined output
70
- gr.Textbox(label="⚠️ Logical Inconsistencies", interactive=False), # Shows inconsistencies correctly
71
- ],
72
- title="πŸ“– FableWeaver AI",
73
- description="Generates AI-powered TinyStories using GPT-2 fine-tuned on TinyStories. Automatically refines the story and detects logical inconsistencies."
74
- )
75
-
76
- # Launch Gradio app
77
- interface.launch(share="True")
 
1
# Third-party imports: UI, text repair, grammar checking, and NLP models.
import gradio as gr
import ftfy
import language_tool_python
import re
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

# Load fine-tuned GPT-2 model from the Hugging Face Hub.
model_path = "cpv2280/gpt2_tinystories_finetuned"  # Update if needed
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Create a text-generation pipeline around the fine-tuned model.
story_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Load NLP tools.
# BUG FIX: LanguageTool expects IETF-style variant codes and 'en-UK' is
# not one of them — the British English code is 'en-GB'.
tool = language_tool_python.LanguageTool('en-GB')
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
22
+
23
def refine_story(text):
    """Refine a generated story: repair encoding, fix grammar, and
    collapse simple word/phrase repetitions.

    Args:
        text: Raw story text from the generator.

    Returns:
        The cleaned-up story text.
    """
    text = ftfy.fix_text(text)  # Repair mojibake / encoding artifacts
    matches = tool.check(text)  # Grammar issues found by LanguageTool
    text = language_tool_python.utils.correct(text, matches)  # Apply fixes

    # Remove redundant words/phrases.
    # BUG FIX: the original pattern r'(\b\w+\b) \1' had no trailing word
    # boundary, so "the theory" matched "the the" and was mangled into
    # "the ory". Anchoring the back-reference with \b fixes that.
    text = re.sub(r'(\b\w+\b) \1\b', r'\1', text)  # Remove duplicate words
    text = re.sub(r'(\b\w+ and \w+\b)(,? \1)+', r'\1', text)  # Remove phrase repetitions

    return text
34
+
35
def detect_inconsistencies(text):
    """Flag adjacent sentence pairs whose semantic similarity is low.

    Args:
        text: Story text; sentences are assumed to be separated by ". ".

    Returns:
        A markdown string listing flagged pairs, or a success message
        when no pair falls below the similarity threshold.
    """
    sentences = text.split(". ")
    inconsistencies = []

    # PERF: encode all sentences once as a batch. The original called
    # sentence_model.encode() twice per loop iteration, re-encoding every
    # interior sentence two times; results are identical.
    embeddings = sentence_model.encode(sentences)

    # Compare each sentence with the next one.
    for i in range(len(sentences) - 1):
        similarity_score = util.pytorch_cos_sim(embeddings[i], embeddings[i + 1])

        if similarity_score.item() < 0.3:  # Low similarity -> flag as inconsistent
            inconsistencies.append(
                f"⚠️ **Possible inconsistency detected:**\n➡ {sentences[i]} \n➡ {sentences[i+1]}"
            )

    return "\n\n".join(inconsistencies) if inconsistencies else "✅ No major inconsistencies detected."
48
+
49
def story_pipeline(prompt):
    """Generate a story from *prompt*, refine it, and report possible
    logical inconsistencies between adjacent sentences.

    Args:
        prompt: Opening text the story should continue from.

    Returns:
        Tuple of (raw story, refined story, inconsistency report).
    """
    # Sampling settings for the fine-tuned GPT-2 generator.
    generation_kwargs = {
        "max_length": 200,
        "do_sample": True,
        "temperature": 1.0,
        "top_p": 0.9,
        "top_k": 50,
    }
    outputs = story_generator(prompt, **generation_kwargs)
    raw_story = outputs[0]["generated_text"]

    # Clean the text, then scan the cleaned version for inconsistencies.
    refined_story = refine_story(raw_story)
    inconsistencies = detect_inconsistencies(refined_story)

    return raw_story, refined_story, inconsistencies
62
+
63
# Gradio interface wiring together generation, refinement, and
# logical-inconsistency detection.
interface = gr.Interface(
    fn=story_pipeline,
    inputs=gr.Textbox(label="Enter Story Prompt", placeholder="Once upon a time..."),
    outputs=[
        gr.Textbox(label="📖 Generated Story", interactive=True),   # Raw model output
        gr.Textbox(label="✅ Refined Story", interactive=True),     # Cleaned-up output
        gr.Textbox(label="⚠️ Logical Inconsistencies", interactive=False),  # Report only
    ],
    title="📖 FableWeaver AI",
    description="Generates AI-powered TinyStories using GPT-2 fine-tuned on TinyStories. Automatically refines the story and detects logical inconsistencies."
)

# Launch Gradio app.
# BUG FIX: `share` expects a boolean. The string "True" only worked by
# accident of truthiness; pass the bool True.
interface.launch(share=True)