cutechicken committed on
Commit
a908cb3
·
verified ·
1 Parent(s): 9a66aa0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -33,21 +33,26 @@ def get_embeddings(text, model, tokenizer):
33
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
34
  with torch.no_grad():
35
  outputs = model(**inputs)
36
- # 마지막 히든 스테이트의 평균을 임베딩으로 사용
37
- embeddings = outputs.last_hidden_state.mean(dim=1)
 
38
  return embeddings
39
 
40
  # ๋ฐ์ดํ„ฐ์…‹์˜ ์งˆ๋ฌธ๋“ค์„ ์ž„๋ฒ ๋”ฉ
41
- questions = wiki_dataset['train']['question'][:10000] # 처음 10000개만 사용
 
42
  question_embeddings = []
43
- batch_size = 32
44
 
45
  for i in range(0, len(questions), batch_size):
46
  batch = questions[i:i+batch_size]
47
  batch_embeddings = get_embeddings(batch, model, tokenizer)
48
- question_embeddings.append(batch_embeddings)
 
 
49
 
50
  question_embeddings = torch.cat(question_embeddings, dim=0)
 
51
 
52
  def find_relevant_context(query, top_k=3):
53
  # 쿼리 임베딩
@@ -56,7 +61,7 @@ def find_relevant_context(query, top_k=3):
56
  # ์ฝ”์‚ฌ์ธ ์œ ์‚ฌ๋„ ๊ณ„์‚ฐ
57
  similarities = cosine_similarity(
58
  query_embedding.cpu().numpy(),
59
- question_embeddings.cpu().numpy()
60
  )[0]
61
 
62
  # 가장 유사한 질문들의 인덱스
 
33
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
34
  with torch.no_grad():
35
  outputs = model(**inputs)
36
+ # hidden states의 평균을 임베딩으로 사용
37
+ hidden_states = outputs[0] # 모델의 마지막 레이어 출력
38
+ embeddings = hidden_states.mean(dim=1)
39
  return embeddings
40
 
41
  # ๋ฐ์ดํ„ฐ์…‹์˜ ์งˆ๋ฌธ๋“ค์„ ์ž„๋ฒ ๋”ฉ
42
+ print("임베딩 생성 시작...")
43
+ questions = wiki_dataset['train']['question'][:1000] # 처음 1000개만 사용 (테스트용)
44
  question_embeddings = []
45
+ batch_size = 8 # 배치 사이즈 줄임
46
 
47
  for i in range(0, len(questions), batch_size):
48
  batch = questions[i:i+batch_size]
49
  batch_embeddings = get_embeddings(batch, model, tokenizer)
50
+ question_embeddings.append(batch_embeddings.cpu())
51
+ if i % 100 == 0:
52
+ print(f"Processed {i}/{len(questions)} questions")
53
 
54
  question_embeddings = torch.cat(question_embeddings, dim=0)
55
+ print("임베딩 생성 완료")
56
 
57
  def find_relevant_context(query, top_k=3):
58
  # 쿼리 임베딩
 
61
  # ์ฝ”์‚ฌ์ธ ์œ ์‚ฌ๋„ ๊ณ„์‚ฐ
62
  similarities = cosine_similarity(
63
  query_embedding.cpu().numpy(),
64
+ question_embeddings.numpy()
65
  )[0]
66
 
67
  # 가장 유사한 질문들의 인덱스