import torch
import requests

from .config import TAG_NAMES, SPACE_URL
# Import the module itself rather than its names: assigning through the module
# attribute keeps the cache shared with other importers of .globals, whereas
# `from .globals import global_model` plus `global` would only rebind this
# module's local copy.
from . import globals as _globals

def predict_single(text, hf_repo, backend="local", hf_token=None):
    """Classify a single text, using either a local checkpoint or the HF Space API."""
    if backend == "local":
        return _predict_local(text, hf_repo)
    elif backend == "hf":
        return _predict_hf_api(text, hf_token)
    else:
        raise ValueError(f"Unknown backend: {backend!r} (expected 'local' or 'hf')")

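# Example usage (hypothetical repo id and output; the actual tag set comes
# from TAG_NAMES in .config, so real labels will differ):
#
#     predict_single("A grimdark revenge tale.", "user/qwen-tag-classifier")
#     # -> "dark(0.91), revenge(0.66)"
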
def _predict_local(text, hf_repo):
    # Lazy-load the model on first use to avoid slow startup.
    if _globals.global_model is None:
        from .model import QwenClassifier
        from transformers import AutoTokenizer
        _globals.global_model = QwenClassifier.from_pretrained(hf_repo).eval()
        _globals.global_tokenizer = AutoTokenizer.from_pretrained(hf_repo)
    inputs = _globals.global_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = _globals.global_model(**inputs)
    return _process_output(logits)

def _predict_hf_api(text, hf_token=None):
    headers = {"Content-Type": "application/json"}
    if hf_token:
        headers["Authorization"] = f"Bearer {hf_token}"
    try:
        response = requests.post(
            f"{SPACE_URL}/predict",
            json={"text": text},  # Matches the server-side Pydantic model
            headers=headers,
            timeout=10,
        )
        response.raise_for_status()  # Surface HTTP 4xx/5xx as exceptions
        return response.json()
    except requests.exceptions.RequestException as e:
        # e.response is None for connection errors and timeouts, so check it
        # explicitly instead of hasattr (the attribute always exists).
        body = e.response.text if e.response is not None else ""
        raise ValueError(f"API Error: {e}\nResponse: {body}") from e

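# Assumed wire format for the Space endpoint: the request body is confirmed by
# the Pydantic note above; the response schema is not defined in this module
# and is returned verbatim from response.json().
#
#     POST {SPACE_URL}/predict
#     request body: {"text": "..."}
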
def _process_output(logits):
    # Multi-label head: an independent sigmoid per tag, thresholded at 0.5.
    probs = torch.sigmoid(logits)
    predicted = [
        f"{tag}({prob.item():.2f})"
        for tag, prob in zip(TAG_NAMES, probs[0])
        if prob > 0.5
    ]
    return ", ".join(predicted)

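if __name__ == "__main__":
    # Minimal smoke test with a hypothetical checkpoint repo id; running it
    # will download the model and tokenizer on first use.
    print(predict_single("A space opera with political intrigue.", "user/qwen-tag-classifier"))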