# Spaces: Runtime error
# import firebase_admin | |
# from firebase_admin import credentials | |
# from firebase_admin import firestore | |
import io | |
from fastapi import FastAPI, File, UploadFile | |
from werkzeug.utils import secure_filename | |
# import speech_recognition as sr | |
import subprocess | |
import os | |
import requests | |
import random | |
import pandas as pd | |
from pydub import AudioSegment | |
from datetime import datetime | |
from datetime import date | |
import numpy as np | |
# from sklearn.ensemble import RandomForestRegressor | |
import shutil | |
import json | |
# from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline | |
from pydantic import BaseModel | |
from typing import Annotated | |
# from transformers import BertTokenizerFast, EncoderDecoderModel | |
import torch | |
import re | |
# from transformers import AutoTokenizer, T5ForConditionalGeneration | |
from fastapi import Form | |
from transformers import AutoModelForSequenceClassification | |
from transformers import TFAutoModelForSequenceClassification | |
from transformers import AutoTokenizer, AutoConfig | |
import numpy as np | |
from scipy.special import softmax | |
def preprocess(text):
    """Normalise tweet-style text before tokenization.

    The text is split on single spaces, so runs of spaces survive as empty
    tokens and are re-joined unchanged. Each token is rewritten as follows:

    - A token that starts with ``@`` and is longer than one character becomes
      ``@user``.
    - A token that starts with ``http`` becomes ``http``.
    - Every other token is left as is.
    """
    def _mask(token):
        # A lone "@" is kept; only real mentions are anonymised.
        if len(token) > 1 and token.startswith('@'):
            return '@user'
        if token.startswith('http'):
            return 'http'
        return token

    return " ".join(_mask(token) for token in text.split(" "))
# Hub model id shared by the tokenizer, the config and the PyTorch model.
# All three are loaded once, when this module is imported.
MODEL = "cardiffnlp/twitter-roberta-base-sentiment-latest"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
config = AutoConfig.from_pretrained(MODEL)  # supplies id2label for the output
model = AutoModelForSequenceClassification.from_pretrained(MODEL)  # PyTorch ("PT") weights
class Query(BaseModel):
    """Input payload for ``get_answer``: the text whose sentiment is scored."""

    text: str  # raw text; ``get_answer`` runs it through ``preprocess`` first
from fastapi import FastAPI, Request, Depends, UploadFile, File | |
from fastapi.exceptions import HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import JSONResponse | |
# now = datetime.now() | |
# UPLOAD_FOLDER = '/files' | |
# ALLOWED_EXTENSIONS = {'txt', 'pdf', 'png', | |
# 'jpg', 'jpeg', 'gif', 'ogg', 'mp3', 'wav'} | |
# ASGI application instance.
app = FastAPI()

# Wide-open CORS policy: every origin, method and header is accepted.
# NOTE(review): the FastAPI CORS docs say credentials are not meant to be
# combined with allow_origins=['*']. Starlette may then echo the caller's
# Origin instead of '*', which effectively allows credentialed requests from
# any site. Confirm whether credentials are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_headers=['*'],
    allow_methods=['*'],
    allow_origins=['*'],
    allow_credentials=True,
)
# cred = credentials.Certificate('key.json') | |
# app1 = firebase_admin.initialize_app(cred) | |
# db = firestore.client() | |
# data_frame = pd.read_csv('data.csv') | |
async def startup_event():
    """Print a console message marking application startup.

    NOTE(review): no ``@app.on_event("startup")`` (or lifespan) registration
    is visible for this coroutine, so FastAPI will not call it by itself.
    Confirm that it is wired up.
    """
    banner = "on startup"
    print(banner)
async def get_answer(q: Query):
    """Classify the sentiment of ``q.text`` with the module-level model.

    The text is first normalised with ``preprocess``: user mentions become
    ``@user`` and links become ``http``. It is then tokenized and scored by
    ``model``, and the logits are converted to probabilities with a softmax.

    Args:
        q: Query whose ``text`` field holds the input to classify.

    Returns:
        A dict mapping each label name from ``config.id2label`` to its
        probability formatted with ``str``, inserted from most to least
        probable.
    """
    text = preprocess(q.text)
    # Truncation stops over-long inputs from overflowing the model's position
    # embeddings, which would otherwise raise. The value 512 assumes the
    # roberta-base limit; confirm against config.max_position_embeddings.
    encoded_input = tokenizer(
        text, return_tensors='pt', truncation=True, max_length=512
    )
    # NOTE(review): this forward pass is synchronous work inside a coroutine,
    # so it blocks the event loop while it runs.
    # Inference only: no_grad skips recording the autograd graph.
    with torch.no_grad():
        output = model(**encoded_input)
    # Take the first (and only) sequence in the batch, giving 1-D logits.
    scores = softmax(output[0][0].numpy())
    # Label indices sorted by descending probability.
    ranking = np.argsort(scores)[::-1]
    return {str(config.id2label[int(idx)]): str(scores[idx]) for idx in ranking}