| from typing import List | |
| import torch | |
| from diffusers import ModularPipelineBlocks | |
| from diffusers.modular_pipelines import PipelineState | |
| from diffusers.modular_pipelines.modular_pipeline_utils import ( | |
| ComponentSpec, | |
| InputParam, | |
| OutputParam, | |
| ) | |
| from image_gen_aux import DepthPreprocessor | |
class DepthProcessorBlock(ModularPipelineBlocks):
    """Pipeline block that converts input image(s) into depth map(s).

    Reads ``image`` from the pipeline state, runs it through the
    ``depth_processor`` component, and stores the result in the state as
    ``condition_image``.
    """

    # ModularPipelineBlocks declares the hooks below as properties and reads
    # them by attribute access (e.g. it iterates ``self.inputs``), so the
    # overrides must be properties as well; a plain method would hand the base
    # class a bound method instead of a list.

    @property
    def expected_components(self):
        """Return the spec of the single ``depth_processor`` component."""
        return [
            ComponentSpec(
                name="depth_processor",
                type_hint=DepthPreprocessor,
                repo="depth-anything/Depth-Anything-V2-Large-hf",
            )
        ]

    @property
    def inputs(self) -> List[InputParam]:
        """Return the block's inputs: one required ``image`` entry."""
        return [
            InputParam(
                "image",
                required=True,
                description="Image(s) to use to extract depth maps",
            )
        ]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        """Return the block's outputs: the ``condition_image`` tensor."""
        return [
            OutputParam(
                "condition_image",
                type_hint=torch.Tensor,
                description="Depth Map(s) of input Image(s)",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Alias of :attr:`intermediates_outputs` under the singular spelling.

        NOTE(review): released diffusers versions appear to read this hook as
        ``intermediate_outputs``; the alias lets the block work under either
        spelling. Confirm against the installed diffusers version.
        """
        return self.intermediates_outputs

    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Compute the depth map for ``image`` and record it in ``state``.

        Args:
            components: Object exposing the ``depth_processor`` callable.
            state: Pipeline state the block reads from and writes to.

        Returns:
            The ``(components, state)`` pair. Note that this is a tuple even
            though the annotation names only ``PipelineState``.
        """
        block_state = self.get_block_state(state)

        depth_map = components.depth_processor(block_state.image)
        # NOTE(review): ``device`` is not among the inputs declared above;
        # confirm the block state actually exposes it at this point.
        block_state.condition_image = depth_map.to(block_state.device)

        self.set_block_state(state, block_state)
        return components, state