Spaces:
Runtime error
Runtime error
Benjamin Bossan
committed on
Commit
·
1643735
1
Parent(s):
3efe4b4
Blacken
Browse files
- src/gistillery/ml.py +19 -6
- src/gistillery/worker.py +3 -1
src/gistillery/ml.py
CHANGED
|
@@ -32,7 +32,9 @@ class Processor(abc.ABC):
|
|
| 32 |
|
| 33 |
|
| 34 |
class Summarizer(abc.ABC):
|
| 35 |
-
def __init__(
|
|
|
|
|
|
|
| 36 |
raise NotImplementedError
|
| 37 |
|
| 38 |
def get_name(self) -> str:
|
|
@@ -44,7 +46,9 @@ class Summarizer(abc.ABC):
|
|
| 44 |
|
| 45 |
|
| 46 |
class Tagger(abc.ABC):
|
| 47 |
-
def __init__(
|
|
|
|
|
|
|
| 48 |
raise NotImplementedError
|
| 49 |
|
| 50 |
def get_name(self) -> str:
|
|
@@ -90,7 +94,9 @@ class MlRegistry:
|
|
| 90 |
|
| 91 |
|
| 92 |
class HfTransformersSummarizer(Summarizer):
|
| 93 |
-
def __init__(
|
|
|
|
|
|
|
| 94 |
self.model_name = model_name
|
| 95 |
self.model = model
|
| 96 |
self.tokenizer = tokenizer
|
|
@@ -101,7 +107,9 @@ class HfTransformersSummarizer(Summarizer):
|
|
| 101 |
def __call__(self, x: str) -> str:
|
| 102 |
text = self.template.format(x)
|
| 103 |
inputs = self.tokenizer(text, return_tensors="pt")
|
| 104 |
-
outputs = self.model.generate(
|
|
|
|
|
|
|
| 105 |
output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
| 106 |
assert isinstance(output, str)
|
| 107 |
return output
|
|
@@ -111,7 +119,9 @@ class HfTransformersSummarizer(Summarizer):
|
|
| 111 |
|
| 112 |
|
| 113 |
class HfTransformersTagger(Tagger):
|
| 114 |
-
def __init__(
|
|
|
|
|
|
|
| 115 |
self.model_name = model_name
|
| 116 |
self.model = model
|
| 117 |
self.tokenizer = tokenizer
|
|
@@ -132,7 +142,9 @@ class HfTransformersTagger(Tagger):
|
|
| 132 |
def __call__(self, x: str) -> list[str]:
|
| 133 |
text = self.template.format(x)
|
| 134 |
inputs = self.tokenizer(text, return_tensors="pt")
|
| 135 |
-
outputs = self.model.generate(
|
|
|
|
|
|
|
| 136 |
output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
| 137 |
tags = self._extract_tags(output)
|
| 138 |
return tags
|
|
@@ -171,6 +183,7 @@ class DefaultUrlProcessor(Processor):
|
|
| 171 |
text = self.template.format(url=self.url, content=text)
|
| 172 |
return text
|
| 173 |
|
|
|
|
| 174 |
# class ProcessorRegistry:
|
| 175 |
# def __init__(self) -> None:
|
| 176 |
# self.registry: list[Processor] = []
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
class Summarizer(abc.ABC):
|
| 35 |
+
def __init__(
|
| 36 |
+
self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
|
| 37 |
+
) -> None:
|
| 38 |
raise NotImplementedError
|
| 39 |
|
| 40 |
def get_name(self) -> str:
|
|
|
|
| 46 |
|
| 47 |
|
| 48 |
class Tagger(abc.ABC):
|
| 49 |
+
def __init__(
|
| 50 |
+
self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
|
| 51 |
+
) -> None:
|
| 52 |
raise NotImplementedError
|
| 53 |
|
| 54 |
def get_name(self) -> str:
|
|
|
|
| 94 |
|
| 95 |
|
| 96 |
class HfTransformersSummarizer(Summarizer):
|
| 97 |
+
def __init__(
|
| 98 |
+
self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
|
| 99 |
+
) -> None:
|
| 100 |
self.model_name = model_name
|
| 101 |
self.model = model
|
| 102 |
self.tokenizer = tokenizer
|
|
|
|
| 107 |
def __call__(self, x: str) -> str:
|
| 108 |
text = self.template.format(x)
|
| 109 |
inputs = self.tokenizer(text, return_tensors="pt")
|
| 110 |
+
outputs = self.model.generate(
|
| 111 |
+
**inputs, generation_config=self.generation_config
|
| 112 |
+
)
|
| 113 |
output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
| 114 |
assert isinstance(output, str)
|
| 115 |
return output
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
class HfTransformersTagger(Tagger):
|
| 122 |
+
def __init__(
|
| 123 |
+
self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
|
| 124 |
+
) -> None:
|
| 125 |
self.model_name = model_name
|
| 126 |
self.model = model
|
| 127 |
self.tokenizer = tokenizer
|
|
|
|
| 142 |
def __call__(self, x: str) -> list[str]:
|
| 143 |
text = self.template.format(x)
|
| 144 |
inputs = self.tokenizer(text, return_tensors="pt")
|
| 145 |
+
outputs = self.model.generate(
|
| 146 |
+
**inputs, generation_config=self.generation_config
|
| 147 |
+
)
|
| 148 |
output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
| 149 |
tags = self._extract_tags(output)
|
| 150 |
return tags
|
|
|
|
| 183 |
text = self.template.format(url=self.url, content=text)
|
| 184 |
return text
|
| 185 |
|
| 186 |
+
|
| 187 |
# class ProcessorRegistry:
|
| 188 |
# def __init__(self) -> None:
|
| 189 |
# self.registry: list[Processor] = []
|
src/gistillery/worker.py
CHANGED
|
@@ -122,7 +122,9 @@ def load_mlregistry(model_name: str) -> MlRegistry:
|
|
| 122 |
# increase the temperature to make the model more creative
|
| 123 |
config_tagger.temperature = 1.5
|
| 124 |
|
| 125 |
-
summarizer = HfTransformersSummarizer(
|
|
|
|
|
|
|
| 126 |
tagger = HfTransformersTagger(model_name, model, tokenizer, config_tagger)
|
| 127 |
|
| 128 |
registry = MlRegistry()
|
|
|
|
| 122 |
# increase the temperature to make the model more creative
|
| 123 |
config_tagger.temperature = 1.5
|
| 124 |
|
| 125 |
+
summarizer = HfTransformersSummarizer(
|
| 126 |
+
model_name, model, tokenizer, config_summarizer
|
| 127 |
+
)
|
| 128 |
tagger = HfTransformersTagger(model_name, model, tokenizer, config_tagger)
|
| 129 |
|
| 130 |
registry = MlRegistry()
|