Update app.py
app.py
CHANGED
@@ -1,3 +1,4 @@
+
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import pandas as pd
@@ -6,6 +7,10 @@ from sklearn.linear_model import LinearRegression
 from io import StringIO
 from gradio.themes.base import Base
 from gradio.themes.utils import colors, fonts
+import torch
+
+# GPU Check (Optional Debug Info)
+print("✅ Model loading... GPU available:", torch.cuda.is_available())
 
 # Custom theme
 custom_theme = Base(
@@ -16,7 +21,11 @@ custom_theme = Base(
 # Load IBM Granite model
 model_name = "ibm-granite/granite-3.3-2b-instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    device_map="auto",
+    torch_dtype=torch.float16  # Faster inference on GPU
+)
 llm = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
 # Module 1: Policy Summarization
@@ -28,7 +37,7 @@ def policy_summarizer_v2(text, file):
     else:
         return "⚠️ Please upload a file or paste some text."
     prompt = f"Summarize the following city policy in simple terms:\n{content}\nSummary:"
-    result = llm(prompt, max_new_tokens=
+    result = llm(prompt, max_new_tokens=100)[0]["generated_text"]
     return result.replace(prompt, "").strip()
 
 # Module 2: Citizen Feedback
@@ -66,7 +75,7 @@ def detect_anomaly(csv_file):
 # Module 6: Chat Assistant
 def chat_assistant(question):
     prompt = f"Answer this smart city sustainability question:\n\nQ: {question}\nA:"
-    result = llm(prompt, max_new_tokens=
+    result = llm(prompt, max_new_tokens=100, temperature=0.7)[0]["generated_text"]
     return result.replace(prompt, "").strip()
 
 # Gradio App UI
@@ -115,4 +124,4 @@ with gr.Blocks(theme=custom_theme) as app:
         chat_btn = gr.Button("Ask")
         chat_btn.click(chat_assistant, inputs=chat_input, outputs=chat_output)
 
-app.launch()
+app.launch()
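
A note on the new generation calls: a transformers text-generation pipeline returns a list of dicts whose "generated_text" includes the prompt, which is why the updated lines index [0]["generated_text"] and then strip the prompt with replace(). The pipeline can also drop the prompt itself via return_full_text=False, and temperature only takes effect when sampling is enabled. A minimal sketch under those assumptions (generate_reply is a hypothetical helper, not part of app.py):

# Minimal sketch, not part of app.py; `generate_reply` is a hypothetical helper.
def generate_reply(llm, prompt: str, max_new_tokens: int = 100) -> str:
    # return_full_text=False returns only the newly generated tokens,
    # so no prompt-stripping via str.replace is needed.
    out = llm(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,      # required for temperature to have any effect
        temperature=0.7,
        return_full_text=False,
    )
    return out[0]["generated_text"].strip()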
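One deployment caveat on the new loading block: device_map="auto" requires the accelerate package, and torch.float16 only helps on GPU; on a CPU-only Space many fp16 ops are slow or unsupported. A hedged sketch of a CPU-safe variant (assumption: fall back to float32 when no CUDA device is present):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name = "ibm-granite/granite-3.3-2b-instruct"
# Assumption: fp16 only when a GPU is available; fp32 fallback on CPU.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",       # requires the `accelerate` package
    torch_dtype=dtype,
)
llm = pipeline("text-generation", model=model, tokenizer=tokenizer)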