Upload folder using huggingface_hub
Browse files
custom_generate/generate.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from typing import Union
|
2 |
import torch
|
3 |
from transformers import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
|
4 |
from transformers.generation.utils import (
|
@@ -15,6 +15,9 @@ import logging
|
|
15 |
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
16 |
from .beam_search import ConstrainedBeamSearchScorer
|
17 |
|
|
|
|
|
|
|
18 |
logger = logging.getLogger(__name__)
|
19 |
|
20 |
def _constrained_beam_search(
|
|
|
1 |
+
from typing import Union, Optional, TYPE_CHECKING
|
2 |
import torch
|
3 |
from transformers import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
|
4 |
from transformers.generation.utils import (
|
|
|
15 |
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
16 |
from .beam_search import ConstrainedBeamSearchScorer
|
17 |
|
18 |
+
if TYPE_CHECKING:
|
19 |
+
from transformers.generation.streamers import BaseStreamer
|
20 |
+
|
21 |
logger = logging.getLogger(__name__)
|
22 |
|
23 |
def _constrained_beam_search(
|