Spaces: Running on Zero
Update app.py
app.py CHANGED

@@ -33,21 +33,26 @@ def get_embeddings(text, model, tokenizer):
     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
     with torch.no_grad():
         outputs = model(**inputs)
-    #
-
+    # use the mean of the hidden states as the embedding
+    hidden_states = outputs[0]  # the model's last-layer output
+    embeddings = hidden_states.mean(dim=1)
     return embeddings
 
 # embed the questions in the dataset
-
+print("Starting embedding generation...")
+questions = wiki_dataset['train']['question'][:1000]  # use only the first 1000 (for testing)
 question_embeddings = []
-batch_size =
+batch_size = 8  # set the batch size
 
 for i in range(0, len(questions), batch_size):
     batch = questions[i:i+batch_size]
     batch_embeddings = get_embeddings(batch, model, tokenizer)
-    question_embeddings.append(batch_embeddings)
+    question_embeddings.append(batch_embeddings.cpu())
+    if i % 100 == 0:
+        print(f"Processed {i}/{len(questions)} questions")
 
 question_embeddings = torch.cat(question_embeddings, dim=0)
+print("Embedding generation complete")
 
 def find_relevant_context(query, top_k=3):
     # embed the query
@@ -56,7 +61,7 @@ def find_relevant_context(query, top_k=3):
     # compute cosine similarity
     similarities = cosine_similarity(
         query_embedding.cpu().numpy(),
-        question_embeddings.
+        question_embeddings.numpy()
     )[0]
 
     # indices of the most similar questions
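A design note on the new pooling step: with padding=True, hidden_states.mean(dim=1) averages the padding positions of shorter sequences together with the real tokens. A mask-weighted mean is the usual alternative; the sketch below (a hypothetical masked_mean_pool helper, not code from this Space) shows it, assuming the tokenizer's attention_mask is available from inputs.

import torch

def masked_mean_pool(hidden_states, attention_mask):
    # hidden_states: (batch, seq_len, hidden) last-layer output
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    mask = attention_mask.unsqueeze(-1).float()   # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(dim=1)    # sum over real tokens only
    counts = mask.sum(dim=1).clamp(min=1e-9)      # real-token count per example
    return summed / counts                        # (batch, hidden)

Inside get_embeddings this would replace hidden_states.mean(dim=1) with masked_mean_pool(hidden_states, inputs["attention_mask"]).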
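The diff cuts off inside find_relevant_context, so only the query embedding and the cosine-similarity call are visible. Below is a hedged sketch of how such a retrieval function typically completes; the np.argsort top-k selection and the return value are assumptions, not code from this commit, and it presumes cosine_similarity comes from sklearn.metrics.pairwise as the call signature suggests.

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def find_relevant_context(query, top_k=3):
    # embed the query with the same mean-pooled encoder -> shape (1, hidden)
    query_embedding = get_embeddings([query], model, tokenizer)

    # cosine similarity of the query against all precomputed question embeddings
    similarities = cosine_similarity(
        query_embedding.cpu().numpy(),
        question_embeddings.numpy()
    )[0]  # shape: (len(questions),)

    # assumed completion: indices of the top_k most similar questions
    top_indices = np.argsort(similarities)[::-1][:top_k]
    return [questions[i] for i in top_indices]

Returning the matched questions themselves is only one plausible choice; the Space may instead look up the paired contexts or answers in wiki_dataset.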