from typing import Any, List, Tuple

import torch
from diffusers import ModularPipelineBlocks
from diffusers.modular_pipelines import PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)
from image_gen_aux import DepthPreprocessor


class DepthProcessorBlock(ModularPipelineBlocks):
    """Pipeline block that turns input image(s) into depth map(s).

    The block reads ``image`` from the pipeline state, runs the
    ``depth_processor`` component on it, and writes the result back to the
    state as ``condition_image``.
    """

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Declare the ``depth_processor`` component this block relies on."""
        return [
            ComponentSpec(
                name="depth_processor",
                type_hint=DepthPreprocessor,
                repo="depth-anything/Depth-Anything-V2-Large-hf",
            )
        ]

    @property
    def inputs(self) -> List[InputParam]:
        """Declare the single required input, ``image``."""
        return [
            InputParam(
                "image",
                required=True,
                description="Image(s) to use to extract depth maps",
            )
        ]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        """Declare the ``condition_image`` value this block produces."""
        return [
            OutputParam(
                "condition_image",
                type_hint=torch.Tensor,
                description="Depth Map(s) of input Image(s)",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> Tuple[Any, PipelineState]:
        """Compute depth map(s) for ``image`` and store them in the state.

        Args:
            components: Object exposing the ``depth_processor`` component.
            state: Pipeline state holding the ``image`` input.

        Returns:
            The ``(components, state)`` pair, with ``condition_image`` written
            into ``state``.
        """
        block_state = self.get_block_state(state)

        depth_map = components.depth_processor(block_state.image)

        # ``device`` is not declared in ``inputs``, so it may be absent from the
        # block state. When it is present, behave exactly as before.
        if hasattr(block_state, "device"):
            depth_map = depth_map.to(block_state.device)
        else:
            # NOTE(review): falls back to the pipeline's execution device when
            # the components object exposes one -- confirm this is the intended
            # source of the target device.
            fallback_device = getattr(components, "_execution_device", None)
            if fallback_device is not None:
                depth_map = depth_map.to(fallback_device)

        block_state.condition_image = depth_map

        self.set_block_state(state, block_state)
        return components, state