Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import time | |
| from abc import abstractmethod | |
| from typing import Any, Tuple, Union | |
| import gradio_client as grc | |
| import huggingface_hub | |
| from gradio_client.client import Job | |
| from gradio_client.utils import QueueError | |
# langchain is an optional dependency: record whether it could be imported so
# that code needing it can check the flag instead of failing at import time.
try:
    import langchain as lc
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    LANGCHAIN_INSTALLED = False
else:
    LANGCHAIN_INSTALLED = True
class GradioTool:
    """Base class that exposes a Gradio app as a named, described tool.

    Subclasses are expected to override ``create_job`` (submit a query to the
    app and return the resulting ``Job``) and ``postprocess`` (turn the job's
    result into a string). ``run`` ties the two together and is the callable
    handed to langchain by ``langchain()``.
    """

    def __init__(
        self,
        name: str,
        description: str,
        src: str,
        hf_token: str | None = None,
        duplicate: bool = True,
    ) -> None:
        """Create the tool and the gradio client it talks to.

        Args:
            name: Tool name (used by ``langchain()`` and ``__repr__``).
            description: Tool description; a note about the ``QUEUE_FULL``
                sentinel returned by ``run`` is appended to it.
            src: Identifier of the app passed to ``gradio_client``.
            hf_token: Optional Hugging Face token forwarded to the client.
            duplicate: When true, and a token is given, and ``src`` is a
                Space, the client is created with ``grc.Client.duplicate``
                and ``self.src`` is taken from the resulting client.
        """
        self.name = name
        self.description = description
        # Test the cheap flags first so the Hub lookup in `_is_space` is only
        # made when duplication was actually requested.
        if hf_token and duplicate and self._is_space(src):
            self.client = grc.Client.duplicate(from_id=src, hf_token=hf_token)
            self.src = self.client.space_id
        else:
            self.src = src
            self.client = grc.Client(self.src, hf_token=hf_token)
        self.description += (
            "If the output is the string QUEUE_FULL you "
            "should try this action again and never stop until you "
            "don't see this message."
        )
        # Lazily populated cache for `block()`.
        self._block = None

    @staticmethod
    def _is_space(src: str) -> bool:
        """Return True if the Hub reports a Space runtime for ``src``.

        Declared static because it takes no ``self`` yet is invoked as
        ``self._is_space(src)`` from ``__init__``.
        """
        try:
            huggingface_hub.get_space_runtime(src)
        except huggingface_hub.hf_api.RepositoryNotFoundError:
            return False
        return True

    @abstractmethod
    def create_job(self, query: str) -> Job:
        """Submit ``query`` to the app and return the job; override in subclasses."""

    @abstractmethod
    def postprocess(self, output: Union[Tuple[Any], Any]) -> str:
        """Convert a finished job's result into a string; override in subclasses."""

    def run(self, query: str):
        """Run the tool on ``query`` and return the post-processed output.

        Polls the job every 30 seconds, printing its status, until it is
        done. Returns the string ``"QUEUE_FULL"`` if fetching the result
        raises ``QueueError``.
        """
        job = self.create_job(query)
        while not job.done():
            status = job.status()
            print(f"\nJob Status: {str(status.code)} eta: {status.eta}")
            time.sleep(30)
        try:
            return self.postprocess(job.result())
        except QueueError:
            return "QUEUE_FULL"

    # Optional gradio functionalities
    @staticmethod
    def _import_gradio(caller: str):
        """Import and return the ``gradio`` module.

        Raises:
            ModuleNotFoundError: If gradio cannot be imported; the message
                names ``caller`` as the method that needed it.
        """
        try:
            import gradio as gr
        except ImportError as err:
            raise ModuleNotFoundError(
                f"gradio must be installed to call {caller}"
            ) from err
        return gr

    def _block_input(self, gr) -> "gr.components.Component":
        """Build the input component from the given gradio module (a Textbox)."""
        return gr.Textbox()

    def _block_output(self, gr) -> "gr.components.Component":
        """Build the output component from the given gradio module (a Textbox)."""
        return gr.Textbox()

    def block_input(self) -> "gr.components.Component":
        """Return the input component; raises ModuleNotFoundError without gradio."""
        return self._block_input(self._import_gradio("block_input"))

    def block_output(self) -> "gr.components.Component":
        """Return the output component; raises ModuleNotFoundError without gradio."""
        return self._block_output(self._import_gradio("block_output"))

    def block(self):
        """Get the gradio Blocks of this tool for visualization."""
        gr = self._import_gradio("block")
        # Load once and reuse; compare against None so the cache test does not
        # depend on the truthiness of the loaded object.
        if self._block is None:
            self._block = gr.load(name=self.src, src="spaces")
        return self._block

    # Optional langchain functionalities
    def langchain(self) -> "langchain.agents.Tool":  # type: ignore
        """Return a langchain ``Tool`` whose function is this tool's ``run``.

        Raises:
            ModuleNotFoundError: If langchain could not be imported.
        """
        if not LANGCHAIN_INSTALLED:
            raise ModuleNotFoundError(
                "langchain must be installed to access langchain tool"
            )
        return lc.agents.Tool(  # type: ignore
            name=self.name, func=self.run, description=self.description
        )

    def __repr__(self) -> str:
        return f"GradioTool(name={self.name}, src={self.src})"