Spaces:
Runtime error
Runtime error
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig | |
| from dataclasses import dataclass | |
| from typing import List, Optional | |
| from utils import ( | |
| get_preprocess_function, | |
| get_utterance_processing_functions, | |
| byt5_decode_batch, | |
| consistent, | |
| ) | |
| from utils import ( | |
| PROGRAM_SPECIAL_TOKEN, | |
| UTTERANCES_SPECIAL_TOKEN, | |
| GT_PROGRAM_SPECIAL_TOKEN, | |
| ) | |
| from greenery import parse | |
| from greenery.parse import NoMatch | |
| import numpy as np | |
| import torch | |
class Agent:
    """Bundle a seq2seq model, its tokenizer and a generation config.

    Args:
        model_path: Name or path given to ``from_pretrained`` for both the
            model and the tokenizer.
        gen_config: Keyword arguments used to construct a ``GenerationConfig``.
        inference_batch_size: Stored as ``self.inference_batch_size``; this
            class does not read it itself.
        device: Target passed to the model's ``.to()`` call.
    """

    def __init__(
        self,
        model_path: str,
        gen_config: dict,
        inference_batch_size: int = 1,
        device=None,
    ):
        loaded = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        self.model = loaded.to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        config_kwargs = gen_config
        self.gen_config = GenerationConfig(**config_kwargs)
        self.inference_batch_size = inference_batch_size
@dataclass
class ListenerOutput:
    """Result container for ``Listener.synthesize``.

    Every field is a list with one entry per input context.

    Attributes:
        programs: Programs kept for each context.
        idx: Positions, within ``decoded``, of the kept programs.
        decoded: All decoded programs for each context.
        decoded_scores: One score per decoded program, aligned with ``decoded``.
        pruned: Not populated by ``Listener.synthesize``.
    """

    programs: List[List[str]]
    idx: Optional[List[List[int]]] = None
    decoded: Optional[List[List[str]]] = None
    decoded_scores: Optional[List[List[float]]] = None
    pruned: Optional[List[List[str]]] = None
class Listener(Agent):
    """Synthesize candidate programs from a context of utterances.

    Generation uses the seq2seq model loaded by ``Agent``. Decoded programs
    are optionally filtered with ``utils.consistent``.
    """

    def __init__(
        self,
        model_path,
        gen_config,
        inference_batch_size=4,
        label_pos="suffix",
        idx: bool = True,
        program_special_token=PROGRAM_SPECIAL_TOKEN,
        utterances_special_token=UTTERANCES_SPECIAL_TOKEN,
        device=None,
    ):
        """
        Args:
            model_path, gen_config, inference_batch_size, device: Forwarded
                to ``Agent.__init__``.
            label_pos: Forwarded to ``get_utterance_processing_functions``.
            idx: Forwarded to ``get_utterance_processing_functions``.
            program_special_token: Decoder prompt placed before every
                generated program.
            utterances_special_token: Prefix added to each encoder input and
                used as the utterance separator.
        """
        super().__init__(model_path, gen_config, inference_batch_size, device)
        self.label_pos = label_pos
        self.idx = idx
        self.program_special_token = program_special_token
        self.utterances_special_token = utterances_special_token
        self.utterances_to_string, self.string_to_utterances = (
            get_utterance_processing_functions(
                label_pos, idx, separator=utterances_special_token
            )
        )
        # Use the model's own device so inputs follow wherever the weights
        # ended up (e.g. when ``device`` was None).
        self.device = self.model.device

    def synthesize(self, context, return_scores=False, enforce_consistency=True):
        """Generate candidate programs for each context in a batch.

        Args:
            context: Batch of contexts. Items are either lists of utterances
                (serialized with ``self.utterances_to_string``) or
                already-serialized strings. The type of the first item
                decides which form is assumed for the whole batch.
            return_scores: If True, also fill ``idx``, ``decoded`` and
                ``decoded_scores`` on the returned ``ListenerOutput``.
            enforce_consistency: If True, keep only programs for which
                ``consistent(program, ctx)`` is truthy. Otherwise keep every
                decoded program.

        Returns:
            ``ListenerOutput`` whose ``programs[i]`` are the kept programs for
            ``context[i]``.
        """
        n_contexts = len(context)
        if n_contexts == 0:
            # Avoid indexing context[0] and running the model on an empty batch.
            if return_scores:
                return ListenerOutput([], [], [], [])
            return ListenerOutput([])

        if isinstance(context[0], list):
            context_str = [self.utterances_to_string(c) for c in context]
        else:
            context_str = context

        prefix = self.utterances_special_token
        encoder_inputs = self.tokenizer(
            [c if c.startswith(prefix) else f"{prefix}{c}" for c in context_str],
            return_tensors="pt",
            padding=True,
        ).to(self.device)
        decoder_inputs = self.tokenizer(
            [self.program_special_token] * n_contexts,
            return_tensors="pt",
            add_special_tokens=False,
        ).to(self.device)
        outputs = self.model.generate(
            **encoder_inputs,
            decoder_input_ids=decoder_inputs.input_ids,
            generation_config=self.gen_config,
            return_dict_in_generate=True,
            # Per-step scores are only needed to compute ``decoded_scores``.
            output_scores=return_scores,
        )

        decoded_batch = byt5_decode_batch(
            outputs.sequences.reshape(
                (n_contexts, -1, outputs.sequences.shape[-1])
            ).tolist(),
            skip_position_token=True,
            skip_special_tokens=True,
        )
        programs, idxs = self._filter_programs(
            decoded_batch, context, enforce_consistency
        )
        if not return_scores:
            return ListenerOutput(programs)
        scores_list = self._sequence_scores(outputs, n_contexts)
        return ListenerOutput(programs, idxs, decoded_batch, scores_list)

    @staticmethod
    def _filter_programs(decoded_batch, context, enforce_consistency):
        """Return (kept programs, their indices) for each context."""
        programs, idxs = [], []
        # NOTE(review): ``ctx`` is the raw context item, so ``consistent``
        # receives a string when the caller passed serialized contexts --
        # confirm utils.consistent accepts that form.
        for decoded, ctx in zip(decoded_batch, context):
            kept = [
                (i, p)
                for i, p in enumerate(decoded)
                if not enforce_consistency or consistent(p, ctx)
            ]
            idxs.append([i for i, _ in kept])
            programs.append([p for _, p in kept])
        return programs, idxs

    @staticmethod
    def _sequence_scores(outputs, n_contexts):
        """Sum generated-token log-probs per sequence, grouped per context."""
        # ``outputs.scores`` holds one (num_sequences, vocab) tensor per step.
        logprobs = torch.stack(outputs.scores, dim=1).log_softmax(dim=-1)
        # The tokens produced at those steps are the last len(scores) positions
        # of ``sequences``. The decoder prompt in front of them may be longer
        # than one token (generate can prepend decoder_start_token_id, and the
        # program token may tokenize to several ids), so slice from the end.
        # NOTE(review): alignment assumes greedy/sampling; with beam search
        # ``scores`` are per-beam and would not line up with ``sequences``.
        n_steps = len(outputs.scores)
        gen_tokens = outputs.sequences[:, -n_steps:]
        token_logprobs = torch.gather(logprobs, 2, gen_tokens[:, :, None]).squeeze(-1)
        # -inf log-probs (tokens masked out by logits processors) count as 0.
        token_logprobs.masked_fill_(token_logprobs.isinf(), 0)
        seq_scores = token_logprobs.sum(-1)
        return seq_scores.reshape((n_contexts, -1)).tolist()