File size: 896 Bytes
090a6eb
 
 
 
 
 
1b497f2
 
 
090a6eb
 
 
 
 
1b497f2
 
 
 
 
 
bd190f2
 
090a6eb
 
 
ed2c942
090a6eb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from functools import lru_cache
from pathlib import Path
from typing import List, Dict, Tuple

import seaborn as sns
from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import TextClassificationPipeline

# Switch seaborn to its default theme for the whole process.
sns.set_theme()

# Absolute path of the "www" directory that sits next to this file
# (the directory part is resolved before joining); used as static assets.
www_dir = Path(__file__).parent.resolve().joinpath("www")

MODEL_NAME = "lincoln/flaubert-mlsum-topic-classification"
DEFAULT_TEXT = "Le Bayern Munich prend la grenadine."


@lru_cache(maxsize=1)
def _get_pipeline() -> TextClassificationPipeline:
    """Load the tokenizer and model once and wrap them in a pipeline.

    The result is cached for the lifetime of the process, so every Shiny
    session shares a single pipeline instead of re-loading the model
    weights each time a browser connects.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    return TextClassificationPipeline(model=model, tokenizer=tokenizer)


def _format_predictions(predictions: List[Dict[str, object]]) -> str:
    """Render pipeline predictions as one ``label: score%`` line each.

    Assumes each prediction is a mapping with ``label`` and ``score`` keys,
    as returned by a text-classification pipeline.
    """
    return "\n".join(
        f"{pred['label']}: {float(pred['score']):.2%}" for pred in predictions
    )


def server(input: Inputs, output: Outputs, session: Session):
    """Shiny server: classify the text from the ``text`` input.

    Results are published through the ``result`` text output. Shiny ignores
    the server function's return value, so nothing is returned here.

    Args:
        input: Reactive inputs; reads ``input.text()``.
        output: Output registry; ``result`` is registered on it.
        session: The current browser session (unused).
    """

    @output
    @render.text
    def result():
        text = input.text()
        # req() quietly halts this render until the user has typed something
        # non-blank, instead of classifying an empty string.
        req(text.strip())
        # truncation=True keeps inputs longer than the model's maximum
        # sequence length from raising inside the tokenizer.
        predictions = _get_pipeline()(text, truncation=True)
        return _format_predictions(predictions)


# The sample sentence is pre-filled in the text area rather than being passed
# to App() (which only accepts ui, server and keyword-only options).
app_ui = ui.page_fluid(
    ui.h2("Topic classification (FlauBERT / MLSUM)"),
    ui.input_text_area("text", "Text to classify", value=DEFAULT_TEXT),
    ui.output_text_verbatim("result"),
)

app = App(app_ui, server, static_assets=str(www_dir))