Spaces:
Sleeping
Sleeping
Allow the user to specify a partially rewritten document.
Browse files
app.py
CHANGED
|
@@ -28,9 +28,10 @@ def get_model(model_name):
|
|
| 28 |
|
| 29 |
prompt = st.text_area("Prompt", "Rewrite this document to be more clear and concise.")
|
| 30 |
doc = st.text_area("Document", "This is a document that I would like to have rewritten to be more concise.")
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
-
def get_spans_local(prompt, doc):
|
| 34 |
import torch
|
| 35 |
|
| 36 |
tokenizer = get_tokenizer(model_name)
|
|
@@ -46,8 +47,10 @@ def get_spans_local(prompt, doc):
|
|
| 46 |
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")[0]
|
| 47 |
assert len(tokenized_chat.shape) == 1
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# Call the model
|
| 53 |
with torch.no_grad():
|
|
@@ -72,18 +75,22 @@ def get_spans_local(prompt, doc):
|
|
| 72 |
length_so_far += len(token)
|
| 73 |
return spans
|
| 74 |
|
| 75 |
-
def get_highlights_api(prompt, doc):
|
| 76 |
# Make a request to the API. prompt and doc are query parameters:
|
| 77 |
# https://tools.kenarnold.org/api/highlights?prompt=Rewrite%20this%20document&doc=This%20is%20a%20document
|
| 78 |
# The response is a JSON array
|
| 79 |
import requests
|
| 80 |
-
response = requests.get("https://tools.kenarnold.org/api/highlights", params=dict(prompt=prompt, doc=doc))
|
| 81 |
return response.json()['highlights']
|
| 82 |
|
| 83 |
if model_name == 'API':
|
| 84 |
-
spans = get_highlights_api(prompt, doc)
|
| 85 |
else:
|
| 86 |
-
spans = get_spans_local(prompt, doc)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
highest_loss = max(span['token_loss'] for span in spans[1:])
|
| 89 |
for span in spans:
|
|
@@ -99,6 +106,5 @@ for span in spans:
|
|
| 99 |
)
|
| 100 |
html_out = f"<p style=\"background: white;\">{html_out}</p>"
|
| 101 |
|
| 102 |
-
st.subheader("Rewritten document")
|
| 103 |
st.write(html_out, unsafe_allow_html=True)
|
| 104 |
st.write(pd.DataFrame(spans)[['token', 'token_loss', 'most_likely_token', 'loss_ratio']])
|
|
|
|
| 28 |
|
| 29 |
prompt = st.text_area("Prompt", "Rewrite this document to be more clear and concise.")
|
| 30 |
doc = st.text_area("Document", "This is a document that I would like to have rewritten to be more concise.")
|
| 31 |
+
updated_doc = st.text_area("Updated Doc", help="Your edited document. Leave this blank to use your original document.")
|
| 32 |
|
| 33 |
|
| 34 |
+
def get_spans_local(prompt, doc, updated_doc):
|
| 35 |
import torch
|
| 36 |
|
| 37 |
tokenizer = get_tokenizer(model_name)
|
|
|
|
| 47 |
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")[0]
|
| 48 |
assert len(tokenized_chat.shape) == 1
|
| 49 |
|
| 50 |
+
if len(updated_doc.strip()) == 0:
|
| 51 |
+
updated_doc = doc
|
| 52 |
+
updated_doc_ids = tokenizer(updated_doc, return_tensors='pt')['input_ids'][0]
|
| 53 |
+
joined_ids = torch.cat([tokenized_chat, updated_doc_ids[1:]])
|
| 54 |
|
| 55 |
# Call the model
|
| 56 |
with torch.no_grad():
|
|
|
|
| 75 |
length_so_far += len(token)
|
| 76 |
return spans
|
| 77 |
|
| 78 |
+
def get_highlights_api(prompt, doc, updated_doc):
|
| 79 |
# Make a request to the API. prompt and doc are query parameters:
|
| 80 |
# https://tools.kenarnold.org/api/highlights?prompt=Rewrite%20this%20document&doc=This%20is%20a%20document
|
| 81 |
# The response is a JSON array
|
| 82 |
import requests
|
| 83 |
+
response = requests.get("https://tools.kenarnold.org/api/highlights", params=dict(prompt=prompt, doc=doc, updated_doc=updated_doc))
|
| 84 |
return response.json()['highlights']
|
| 85 |
|
| 86 |
if model_name == 'API':
|
| 87 |
+
spans = get_highlights_api(prompt, doc, updated_doc)
|
| 88 |
else:
|
| 89 |
+
spans = get_spans_local(prompt, doc, updated_doc)
|
| 90 |
+
|
| 91 |
+
if len(spans) < 2:
|
| 92 |
+
st.write("No spans found.")
|
| 93 |
+
st.stop()
|
| 94 |
|
| 95 |
highest_loss = max(span['token_loss'] for span in spans[1:])
|
| 96 |
for span in spans:
|
|
|
|
| 106 |
)
|
| 107 |
html_out = f"<p style=\"background: white;\">{html_out}</p>"
|
| 108 |
|
|
|
|
| 109 |
st.write(html_out, unsafe_allow_html=True)
|
| 110 |
st.write(pd.DataFrame(spans)[['token', 'token_loss', 'most_likely_token', 'loss_ratio']])
|