yongyeol committed on
Commit
f010769
·
verified ·
1 Parent(s): 8d2a103

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +38 -22
src/streamlit_app.py CHANGED
@@ -3,9 +3,10 @@ import json
3
  import requests
4
  import streamlit as st
5
  from datetime import datetime
6
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
7
 
8
- # ✅ 안전한 캐시 경로 설정 (최상단 필수)
9
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
10
  os.environ["HF_HOME"] = "/tmp/hf_cache"
11
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
@@ -15,23 +16,25 @@ st.set_page_config(page_title="학사일정 μΊ˜λ¦°λ”", layout="centered")
15
  st.title("πŸ“… 학사일정 μΊ˜λ¦°λ” + AI μš”μ•½")
16
  st.markdown("NEIS APIμ—μ„œ 학사일정을 뢈러였고 FullCalendar둜 μ‹œκ°ν™”ν•©λ‹ˆλ‹€.")
17
 
18
- # ✅ 디버깅 출력
19
- token_present = os.environ.get("HUGGINGFACE_TOKEN") is not None
20
- st.write("πŸ” 토큰 있음:", token_present)
21
- st.write("βœ… μΊμ‹œ 경둜:", os.environ.get("TRANSFORMERS_CACHE"))
22
-
23
- # ✅ Gemma 모델 로딩 함수
24
  @st.cache_resource
25
  def load_model():
26
  token = os.environ.get("HUGGINGFACE_TOKEN")
27
- model_id = "google/gemma-2-2b-it"
28
  cache_dir = "/tmp/hf_cache"
29
 
30
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token, cache_dir=cache_dir)
31
- model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token, cache_dir=cache_dir)
32
- return pipeline("text-generation", model=model, tokenizer=tokenizer)
 
 
 
 
 
 
 
33
 
34
- llm = load_model()
35
 
36
  # ✅ 학교 정보 가져오기
37
  def get_school_info(region_code, school_name, api_key):
@@ -41,7 +44,7 @@ def get_school_info(region_code, school_name, api_key):
41
  school = data.get("schoolInfo", [{}])[1].get("row", [{}])[0]
42
  return school.get("SD_SCHUL_CODE"), school.get("ATPT_OFCDC_SC_CODE")
43
 
44
- # ✅ 학사일정 가져오기 (월 단위)
45
  def get_schedule(region_code, school_code, year, month, api_key):
46
  from_ymd = f"{year}{month:02}01"
47
  to_ymd = f"{year}{month:02}31"
@@ -49,13 +52,13 @@ def get_schedule(region_code, school_code, year, month, api_key):
49
  res = requests.get(url)
50
  data = res.json()
51
  rows = data.get("SchoolSchedule", [{}])[1].get("row", [])
52
- st.write("πŸ“¦ 뢈러온 일정 raw data:", rows)
53
  return rows
54
 
55
  # ✅ 요약 생성
56
  def summarize_schedule(rows, school_name, year):
57
  if not rows:
58
  return "일정이 μ—†μ–΄ μš”μ•½ν•  수 μ—†μŠ΅λ‹ˆλ‹€."
 
59
  lines = []
60
  for row in rows:
61
  date = row["AA_YMD"]
@@ -63,11 +66,26 @@ def summarize_schedule(rows, school_name, year):
63
  event = row["EVENT_NM"]
64
  lines.append(f"{dt}: {event}")
65
  text = "\n".join(lines)
 
66
  prompt = f"{school_name}κ°€ {year}년도에 κ°€μ§€λŠ” 학사일정은 λ‹€μŒκ³Ό κ°™μŠ΅λ‹ˆλ‹€:\n{text}\nμ£Όμš” 일정을 μš”μ•½ν•΄μ£Όμ„Έμš”."
67
- st.write("πŸ“€ μš”μ•½μ— μ „λ‹¬λœ ν”„λ‘¬ν”„νŠΈ:", prompt)
68
- result = llm([{"role": "user", "content": prompt}])
69
- st.write("πŸ“₯ λͺ¨λΈ 생성 κ²°κ³Ό:", result)
70
- return result[0]["generated_text"].replace(prompt, "").strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  # ✅ 지역/학교/년도/월 선택 UI
73
  region_options = {
@@ -83,6 +101,7 @@ with st.form("query_form"):
83
  month = st.selectbox("μ›”", options=list(range(1, 13)), index=6)
84
  submitted = st.form_submit_button("πŸ“… 학사일정 뢈러였기")
85
 
 
86
  if submitted:
87
  with st.spinner("일정 λΆˆλŸ¬μ˜€λŠ” 쀑..."):
88
  api_key = os.environ.get("NEIS_API_KEY", "a69e08342c8947b4a52cd72789a5ecaf")
@@ -94,7 +113,6 @@ if submitted:
94
  if not schedule_rows:
95
  st.info("ν•΄λ‹Ή 쑰건의 학사일정이 μ—†μŠ΅λ‹ˆλ‹€.")
96
  else:
97
- # ✅ 일정 출력용 FullCalendar 생성
98
  events = [
99
  {
100
  "title": row["EVENT_NM"],
@@ -103,7 +121,6 @@ if submitted:
103
  for row in schedule_rows
104
  if "AA_YMD" in row and "EVENT_NM" in row
105
  ]
106
- st.write("πŸ“… FullCalendar에 전달할 events:", events)
107
  event_json = json.dumps(events, ensure_ascii=False)
108
 
109
  st.components.v1.html(f"""
@@ -130,10 +147,9 @@ if submitted:
130
  </html>
131
  """, height=650)
132
 
133
- # ✅ 요약 생성 버튼 추가
134
  with st.expander("✨ 1λ…„μΉ˜ μš”μ•½ 보기", expanded=False):
135
  if st.button("πŸ€– μš”μ•½ μƒμ„±ν•˜κΈ°"):
136
- with st.spinner("Gemma λͺ¨λΈμ΄ μš”μ•½ 쀑..."):
137
  summary = summarize_schedule(schedule_rows, school_name, year)
138
  st.success("μš”μ•½ μ™„λ£Œ!")
139
  st.markdown(f"**{school_name} {year}λ…„ {month}μ›” 일정 μš”μ•½:**\n\n{summary}")
 
3
  import requests
4
  import streamlit as st
5
  from datetime import datetime
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ import torch
8
 
9
+ # ✅ 안전한 캐시 경로 설정
10
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
11
  os.environ["HF_HOME"] = "/tmp/hf_cache"
12
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
 
16
  st.title("πŸ“… 학사일정 μΊ˜λ¦°λ” + AI μš”μ•½")
17
  st.markdown("NEIS APIμ—μ„œ 학사일정을 뢈러였고 FullCalendar둜 μ‹œκ°ν™”ν•©λ‹ˆλ‹€.")
18
 
19
+ # ✅ 모델 로딩 함수 (skt/A.X-4.0-Light)
 
 
 
 
 
20
  @st.cache_resource
21
  def load_model():
22
  token = os.environ.get("HUGGINGFACE_TOKEN")
23
+ model_id = "skt/A.X-4.0-Light"
24
  cache_dir = "/tmp/hf_cache"
25
 
26
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token, cache_dir=cache_dir)
27
+ model = AutoModelForCausalLM.from_pretrained(
28
+ model_id,
29
+ use_auth_token=token,
30
+ torch_dtype=torch.bfloat16,
31
+ device_map="auto",
32
+ cache_dir=cache_dir
33
+ )
34
+ model.eval()
35
+ return tokenizer, model
36
 
37
+ tokenizer, model = load_model()
38
 
39
  # ✅ 학교 정보 가져오기
40
  def get_school_info(region_code, school_name, api_key):
 
44
  school = data.get("schoolInfo", [{}])[1].get("row", [{}])[0]
45
  return school.get("SD_SCHUL_CODE"), school.get("ATPT_OFCDC_SC_CODE")
46
 
47
+ # ✅ 학사일정 가져오기
48
  def get_schedule(region_code, school_code, year, month, api_key):
49
  from_ymd = f"{year}{month:02}01"
50
  to_ymd = f"{year}{month:02}31"
 
52
  res = requests.get(url)
53
  data = res.json()
54
  rows = data.get("SchoolSchedule", [{}])[1].get("row", [])
 
55
  return rows
56
 
57
  # ✅ 요약 생성
58
  def summarize_schedule(rows, school_name, year):
59
  if not rows:
60
  return "일정이 μ—†μ–΄ μš”μ•½ν•  수 μ—†μŠ΅λ‹ˆλ‹€."
61
+
62
  lines = []
63
  for row in rows:
64
  date = row["AA_YMD"]
 
66
  event = row["EVENT_NM"]
67
  lines.append(f"{dt}: {event}")
68
  text = "\n".join(lines)
69
+
70
  prompt = f"{school_name}κ°€ {year}년도에 κ°€μ§€λŠ” 학사일정은 λ‹€μŒκ³Ό κ°™μŠ΅λ‹ˆλ‹€:\n{text}\nμ£Όμš” 일정을 μš”μ•½ν•΄μ£Όμ„Έμš”."
71
+
72
+ messages = [
73
+ {"role": "system", "content": "당신은 학사일정을 μš”μ•½ν•΄μ£ΌλŠ” AIμž…λ‹ˆλ‹€."},
74
+ {"role": "user", "content": prompt}
75
+ ]
76
+
77
+ input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
78
+
79
+ with torch.no_grad():
80
+ output = model.generate(
81
+ input_ids,
82
+ max_new_tokens=256,
83
+ do_sample=False,
84
+ )
85
+
86
+ len_prompt = input_ids.shape[1]
87
+ response = tokenizer.decode(output[0][len_prompt:], skip_special_tokens=True).strip()
88
+ return response
89
 
90
  # ✅ 지역/학교/년도/월 선택 UI
91
  region_options = {
 
101
  month = st.selectbox("μ›”", options=list(range(1, 13)), index=6)
102
  submitted = st.form_submit_button("πŸ“… 학사일정 뢈러였기")
103
 
104
+ # ✅ 제출 처리
105
  if submitted:
106
  with st.spinner("일정 λΆˆλŸ¬μ˜€λŠ” 쀑..."):
107
  api_key = os.environ.get("NEIS_API_KEY", "a69e08342c8947b4a52cd72789a5ecaf")
 
113
  if not schedule_rows:
114
  st.info("ν•΄λ‹Ή 쑰건의 학사일정이 μ—†μŠ΅λ‹ˆλ‹€.")
115
  else:
 
116
  events = [
117
  {
118
  "title": row["EVENT_NM"],
 
121
  for row in schedule_rows
122
  if "AA_YMD" in row and "EVENT_NM" in row
123
  ]
 
124
  event_json = json.dumps(events, ensure_ascii=False)
125
 
126
  st.components.v1.html(f"""
 
147
  </html>
148
  """, height=650)
149
 
 
150
  with st.expander("✨ 1λ…„μΉ˜ μš”μ•½ 보기", expanded=False):
151
  if st.button("πŸ€– μš”μ•½ μƒμ„±ν•˜κΈ°"):
152
+ with st.spinner("λͺ¨λΈμ΄ μš”μ•½ 쀑..."):
153
  summary = summarize_schedule(schedule_rows, school_name, year)
154
  st.success("μš”μ•½ μ™„λ£Œ!")
155
  st.markdown(f"**{school_name} {year}λ…„ {month}μ›” 일정 μš”μ•½:**\n\n{summary}")