-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathcomfyui_mask_sequence_ops.py
93 lines (72 loc) · 2.24 KB
/
comfyui_mask_sequence_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
# Module-level registries filled in by ``register_node``:
#   NODE_CLASS_MAPPINGS:        identifier -> node class
#   NODE_DISPLAY_NAME_MAPPINGS: identifier -> human-readable name
# NOTE(review): presumably these exact names are what ComfyUI's custom-node
# loader looks for — confirm against the package's __init__.
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}


def register_node(identifier: str, display_name: str):
    """Build a class decorator that records a node class in the registries.

    Args:
        identifier: Key under which both the class and its display name are
            stored. Registering the same identifier again replaces the
            earlier entries.
        display_name: Human-readable name stored in
            ``NODE_DISPLAY_NAME_MAPPINGS``.

    Returns:
        A decorator that registers the decorated class and returns it
        unchanged. Because the class is returned as-is, node classes can all
        be bound to the throwaway name ``_``.
    """

    def _add_to_registry(node_cls):
        NODE_CLASS_MAPPINGS.update({identifier: node_cls})
        NODE_DISPLAY_NAME_MAPPINGS.update({identifier: display_name})
        return node_cls

    return _add_to_registry
@register_node("JWMaskSequenceFromMask", "Mask Sequence From Mask")
class _:
    """Build a mask sequence by repeating one mask ``batch_size`` times.

    Output shape is ``(batch_size, 1, H, W)``.
    """

    CATEGORY = "jamesWalker55"
    INPUT_TYPES = lambda: {
        "required": {
            "mask": ("MASK",),
            "batch_size": ("INT", {"default": 1, "min": 1, "step": 1}),
        }
    }
    RETURN_TYPES = ("MASK_SEQUENCE",)
    FUNCTION = "execute"

    def execute(
        self,
        mask: torch.Tensor,
        batch_size: int,
    ):
        """Repeat ``mask`` into a ``(batch_size, 1, H, W)`` mask sequence.

        Args:
            mask: A single mask, either ``(H, W)`` or with a leading
                singleton batch axis ``(1, H, W)``.
            batch_size: Number of copies along the first axis (the node's
                ``INPUT_TYPES`` declares a minimum of 1).

        Returns:
            A one-element tuple holding the ``(batch_size, 1, H, W)`` tensor.

        Raises:
            AssertionError: If an argument has the wrong type, or ``mask`` is
                neither ``(H, W)`` nor ``(1, H, W)``. These checks use
                ``assert`` and are therefore skipped under ``python -O``.
        """
        assert isinstance(mask, torch.Tensor)
        assert isinstance(batch_size, int)
        # A single mask may arrive with a leading batch axis of size 1.
        # Dropping it yields exactly the result the (H, W) form would give.
        if mask.dim() == 3 and mask.shape[0] == 1:
            mask = mask[0]
        assert mask.dim() == 2, (
            f"expected a (H, W) or (1, H, W) mask, got shape {tuple(mask.shape)}"
        )
        mask_seq = mask.reshape((1, 1, *mask.shape))
        # repeat() materialises independent copies rather than a view that
        # shares storage with ``mask`` (as expand() would).
        mask_seq = mask_seq.repeat(batch_size, 1, 1, 1)
        return (mask_seq,)
@register_node("JWMaskSequenceJoin", "Join Mask Sequence")
class _:
    """Concatenate two mask sequences along their first axis."""

    CATEGORY = "jamesWalker55"
    FUNCTION = "execute"
    RETURN_TYPES = ("MASK_SEQUENCE",)
    INPUT_TYPES = lambda: {
        "required": {
            "mask_sequence_1": ("MASK_SEQUENCE",),
            "mask_sequence_2": ("MASK_SEQUENCE",),
        }
    }

    def execute(self, mask_sequence_1: torch.Tensor, mask_sequence_2: torch.Tensor):
        """Return ``mask_sequence_1`` followed by ``mask_sequence_2`` along dim 0.

        ``torch.cat`` requires every dimension except the first to agree and
        raises if they do not.

        Returns:
            A one-element tuple holding the concatenated tensor.

        Raises:
            AssertionError: If either argument is not a ``torch.Tensor``
                (checked with ``assert``, so skipped under ``python -O``).
        """
        for seq in (mask_sequence_1, mask_sequence_2):
            assert isinstance(seq, torch.Tensor)
        joined = torch.cat([mask_sequence_1, mask_sequence_2], dim=0)
        return (joined,)
@register_node("JWMaskSequenceApplyToLatent", "Apply Mask Sequence to Latent")
class _:
    """Attach a mask sequence to a latent dict as its ``"noise_mask"`` entry."""

    CATEGORY = "jamesWalker55"
    FUNCTION = "execute"
    RETURN_TYPES = ("LATENT",)
    INPUT_TYPES = lambda: {
        "required": {
            "samples": ("LATENT",),
            "mask_sequence": ("MASK_SEQUENCE",),
        }
    }

    def execute(self, samples: dict, mask_sequence: torch.Tensor):
        """Return a copy of ``samples`` whose ``"noise_mask"`` is ``mask_sequence``.

        The caller's dict is not mutated. Only a shallow copy is modified, so
        the values it holds are shared with the original. An existing
        ``"noise_mask"`` entry is overwritten.

        NOTE(review): nothing here checks that the mask sequence matches the
        latent's batch or spatial size. Confirm the consumer handles that.

        Returns:
            A one-element tuple holding the new latent dict.

        Raises:
            AssertionError: If ``samples`` is not a dict or ``mask_sequence``
                is not a ``torch.Tensor`` (checked with ``assert``, so skipped
                under ``python -O``).
        """
        assert isinstance(samples, dict)
        assert isinstance(mask_sequence, torch.Tensor)
        latent = samples.copy()
        latent.update({"noise_mask": mask_sequence})
        return (latent,)