Spaces:
Runtime error
Runtime error
| /// Multi shard Client | |
| use crate::Result; | |
| use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo}; | |
| use futures::future::join_all; | |
| use tonic::transport::Uri; | |
| use tracing::instrument; | |
| /// Text Generation Inference gRPC multi client | |
| pub struct ShardedClient { | |
| clients: Vec<Client>, | |
| } | |
| impl ShardedClient { | |
| fn new(clients: Vec<Client>) -> Self { | |
| Self { clients } | |
| } | |
| /// Create a new ShardedClient from a master client. The master client will communicate with | |
| /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method. | |
| async fn from_master_client(mut master_client: Client) -> Result<Self> { | |
| // Get all uris/unix sockets from the master client | |
| let uris = master_client.service_discovery().await?; | |
| let futures = uris.into_iter().map(Client::connect_uds); | |
| let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect(); | |
| Ok(Self::new(clients?)) | |
| } | |
| /// Returns a client connected to the given uri | |
| pub async fn connect(uri: Uri) -> Result<Self> { | |
| let master_client = Client::connect(uri).await?; | |
| Self::from_master_client(master_client).await | |
| } | |
| /// Returns a client connected to the given unix socket | |
| pub async fn connect_uds(path: String) -> Result<Self> { | |
| let master_client = Client::connect_uds(path).await?; | |
| Self::from_master_client(master_client).await | |
| } | |
| /// Get the model info | |
| pub async fn info(&mut self) -> Result<ShardInfo> { | |
| let futures: Vec<_> = self | |
| .clients | |
| .iter_mut() | |
| .map(|client| client.info()) | |
| .collect(); | |
| join_all(futures).await.pop().unwrap() | |
| } | |
| /// GRPC health check | |
| pub async fn health(&mut self) -> Result<HealthResponse> { | |
| let futures: Vec<_> = self | |
| .clients | |
| .iter_mut() | |
| .map(|client| client.health()) | |
| .collect(); | |
| join_all(futures).await.pop().unwrap() | |
| } | |
| /// Clear the past generations cache | |
| pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> { | |
| let futures: Vec<_> = self | |
| .clients | |
| .iter_mut() | |
| .map(|client| client.clear_cache(batch_id)) | |
| .collect(); | |
| join_all(futures).await.into_iter().collect() | |
| } | |
| /// Filter a cached batch | |
| pub async fn filter_batch( | |
| &mut self, | |
| batch_id: u64, | |
| keep_requests: Vec<Request>, | |
| ) -> Result<Option<Batch>> { | |
| let futures: Vec<_> = self | |
| .clients | |
| .iter_mut() | |
| .map(|client| Box::pin(client.filter_batch(batch_id, keep_requests.clone()))) | |
| .collect(); | |
| // all shards return the same message | |
| join_all(futures).await.pop().unwrap() | |
| } | |
| /// Generate one token for each request in the given batch | |
| /// | |
| /// Returns Generation for each request in batch | |
| /// and the next cached batch | |
| pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> { | |
| let futures: Vec<_> = self | |
| .clients | |
| .iter_mut() | |
| .map(|client| Box::pin(client.prefill(batch.clone()))) | |
| .collect(); | |
| // all shards return the same message | |
| join_all(futures).await.pop().unwrap() | |
| } | |
| /// Generate one token for each request in the given cached batches | |
| /// | |
| /// Returns Generation for each request in batches | |
| /// and the next cached batch | |
| pub async fn decode( | |
| &mut self, | |
| batches: Vec<Batch>, | |
| ) -> Result<(Vec<Generation>, Option<Batch>)> { | |
| let futures: Vec<_> = self | |
| .clients | |
| .iter_mut() | |
| .map(|client| Box::pin(client.decode(batches.clone()))) | |
| .collect(); | |
| // all shards return the same message | |
| join_all(futures).await.pop().unwrap() | |
| } | |
| } | |