cbensimon HF Staff commited on
Commit
d00873b
·
1 Parent(s): 3af4a0e

More cleanup

Browse files
optimization.py CHANGED
@@ -9,8 +9,8 @@ import spaces
9
  import torch
10
  from torch.utils._pytree import tree_map_only
11
 
12
- from pipeline_utils import capture_component_call
13
- from zerogpu import aoti_compile
14
 
15
 
16
  P = ParamSpec('P')
 
9
  import torch
10
  from torch.utils._pytree import tree_map_only
11
 
12
+ from optimization_utils import capture_component_call
13
+ from optimization_utils import aoti_compile
14
 
15
 
16
  P = ParamSpec('P')
zerogpu.py → optimization_utils.py RENAMED
@@ -1,9 +1,11 @@
1
  """
2
  """
 
3
  from contextvars import ContextVar
4
  from io import BytesIO
5
  from typing import Any
6
  from typing import cast
 
7
 
8
  import torch
9
  from torch._inductor.package.package import package_aoti
@@ -60,3 +62,35 @@ def aoti_compile(
60
  package_aoti(archive_file, files)
61
  weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
62
  return ZeroGPUCompiledModel(archive_file, weights)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  """
3
+ import contextlib
4
  from contextvars import ContextVar
5
  from io import BytesIO
6
  from typing import Any
7
  from typing import cast
8
+ from unittest.mock import patch
9
 
10
  import torch
11
  from torch._inductor.package.package import package_aoti
 
62
  package_aoti(archive_file, files)
63
  weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
64
  return ZeroGPUCompiledModel(archive_file, weights)
65
+
66
+
67
+ @contextlib.contextmanager
68
+ def capture_component_call(
69
+ pipeline: Any,
70
+ component_name: str,
71
+ component_method='forward',
72
+ ):
73
+
74
+ class CapturedCallException(Exception):
75
+ def __init__(self, *args, **kwargs):
76
+ super().__init__()
77
+ self.args = args
78
+ self.kwargs = kwargs
79
+
80
+ class CapturedCall:
81
+ def __init__(self):
82
+ self.args: tuple[Any, ...] = ()
83
+ self.kwargs: dict[str, Any] = {}
84
+
85
+ component = getattr(pipeline, component_name)
86
+ captured_call = CapturedCall()
87
+
88
+ def capture_call(*args, **kwargs):
89
+ raise CapturedCallException(*args, **kwargs)
90
+
91
+ with patch.object(component, component_method, new=capture_call):
92
+ try:
93
+ yield captured_call
94
+ except CapturedCallException as e:
95
+ captured_call.args = e.args
96
+ captured_call.kwargs = e.kwargs
pipeline_utils.py DELETED
@@ -1,40 +0,0 @@
1
- """
2
- """
3
-
4
- import contextlib
5
- from unittest.mock import patch
6
-
7
- from typing import Any
8
-
9
-
10
- class CapturedCallException(Exception):
11
- def __init__(self, *args, **kwargs):
12
- super().__init__()
13
- self.args = args
14
- self.kwargs = kwargs
15
-
16
-
17
- class CapturedCall:
18
- def __init__(self):
19
- self.args: tuple[Any, ...] = ()
20
- self.kwargs: dict[str, Any] = {}
21
-
22
-
23
- @contextlib.contextmanager
24
- def capture_component_call(
25
- pipeline: Any,
26
- component_name: str,
27
- component_method='forward',
28
- ):
29
- component = getattr(pipeline, component_name)
30
- captured_call = CapturedCall()
31
-
32
- def capture_call(*args, **kwargs):
33
- raise CapturedCallException(*args, **kwargs)
34
-
35
- with patch.object(component, component_method, new=capture_call):
36
- try:
37
- yield captured_call
38
- except CapturedCallException as e:
39
- captured_call.args = e.args
40
- captured_call.kwargs = e.kwargs