File size: 1,534 Bytes
d6f85ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
from typing import Any, List, Tuple

import torch
from diffusers import ModularPipelineBlocks
from diffusers.modular_pipelines import PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)
from image_gen_aux import DepthPreprocessor
class DepthProcessorBlock(ModularPipelineBlocks):
    """Modular-pipeline block that turns input image(s) into depth map(s).

    Reads ``image`` from the pipeline state, runs it through the
    ``depth_processor`` component (an ``image_gen_aux.DepthPreprocessor``),
    and writes the result back into the state as ``condition_image``.
    """

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components this block needs; ``repo`` names the depth checkpoint."""
        return [
            ComponentSpec(
                name="depth_processor",
                type_hint=DepthPreprocessor,
                repo="depth-anything/Depth-Anything-V2-Large-hf",
            )
        ]

    @property
    def inputs(self) -> List[InputParam]:
        """State values this block reads: the required ``image`` input."""
        return [
            InputParam(
                "image",
                required=True,
                description="Image(s) to use to extract depth maps",
            )
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """State values this block writes: the ``condition_image`` tensor."""
        return [
            OutputParam(
                "condition_image",
                type_hint=torch.Tensor,
                description="Depth Map(s) of input Image(s)",
            ),
        ]

    # This block originally spelled the property ``intermediates_outputs``;
    # current diffusers reads ``intermediate_outputs``. Expose both names so the
    # declared output is picked up either way.
    # NOTE(review): confirm against the installed diffusers version.
    intermediates_outputs = intermediate_outputs

    @torch.no_grad()
    def __call__(
        self, components, state: PipelineState
    ) -> Tuple[Any, PipelineState]:
        """Compute depth map(s) for ``image`` and store them as ``condition_image``.

        Args:
            components: The pipeline's component container; must provide
                ``depth_processor``.
            state: Pipeline state holding the ``image`` input.

        Returns:
            The ``(components, state)`` pair, with ``condition_image`` set on
            ``state``.
        """
        block_state = self.get_block_state(state)

        # DepthPreprocessor's default return_type is "pil" (a list of PIL
        # images), which has no ``.to()`` and does not match the declared
        # ``torch.Tensor`` output, so explicitly request a tensor.
        depth_map = components.depth_processor(block_state.image, return_type="pt")

        # ``device`` is not a declared input, so it is not guaranteed to be on
        # the block state; fall back to the pipeline's execution device.
        device = getattr(block_state, "device", None)
        if device is None:
            device = components._execution_device
        block_state.condition_image = depth_map.to(device)

        self.set_block_state(state, block_state)
        return components, state
|