oucgc1996 commited on
Commit
6ecb800
·
verified ·
1 Parent(s): d4eb073

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -0
app.py CHANGED
@@ -1,3 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def CTXGen(τ, g_num, length_range, progress=gr.Progress()):
2
  start, end = length_range
3
  X1 = "X"
 
1
+ import random
2
+ import torch
3
+ import gradio as gr
4
+ from gradio_rangeslider import RangeSlider
5
+ import pandas as pd
6
+ from utils import create_vocab, setup_seed
7
+ from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
8
+ import time
9
+ seed = random.randint(0,100000)
10
+
11
+ setup_seed(seed)
12
+ device = torch.device("cpu")
13
+ vocab_mlm = create_vocab()
14
+ vocab_mlm = add_tokens_to_vocab(vocab_mlm)
15
+ save_path = 'mlm-model-27.pt'
16
+ train_seqs = pd.read_csv('C0_seq.csv')
17
+ train_seq = train_seqs['Seq'].tolist()
18
+ model = torch.load(save_path, map_location=torch.device('cpu'))
19
+ model = model.to(device)
20
+
21
+ def temperature_sampling(logits, temperature):
22
+ logits = logits / temperature
23
+ probabilities = torch.softmax(logits, dim=-1)
24
+ sampled_token = torch.multinomial(probabilities, 1)
25
+ return sampled_token
26
+
27
  def CTXGen(τ, g_num, length_range, progress=gr.Progress()):
28
  start, end = length_range
29
  X1 = "X"