Federated learning

APIs for federated learning.

FederatedPipelineBase

Base class for Federated Learning pipelines.

build(self, config)

Constructs the federated pipeline. Users do not need to override this method.

Source code in shrike/pipeline/federated_learning.py
@experimental(message=EXPERIMENTAL_WARNING_MSG)
def build(self, config: DictConfig):
    """Constructs the federated pipeline. User does not need to modify."""

    @dsl.pipeline()
    def pipeline_function():
        self._merge_config()
        prev = self.preprocess(self.config)
        prev_output = self._process_at_orchestrator(prev) if prev else None

        for iter in range(self.config.federated_config.max_iterations):
            cool_name = self.create_base_name()
            midprocess_input = []
            for silo_name, silo in self.config.federated_config.silos.items():
                train_output = self._train_in_silo_once(
                    silo_name, silo, prev_output, cool_name
                )
                midprocess_input += train_output

            prev = self.midprocess(self.config, input=midprocess_input)
            prev_output = self._process_at_orchestrator(prev) if prev else None

        prev = self.postprocess(self.config, input=prev_output)
        prev_output = self._process_at_orchestrator(prev) if prev else None

    return pipeline_function
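The call order in `build` can be easier to follow outside of the pipeline DSL. The following is a minimal sketch that reproduces only the control flow (preprocess once, then train in every silo followed by midprocess for each iteration, then postprocess once). The function `run_rounds` and the lambda steps are hypothetical stand-ins, not part of shrike, and the orchestrator-side processing is omitted.

```python
from typing import Callable, List, Optional


def run_rounds(
    silos: List[str],
    max_iterations: int,
    preprocess: Callable[[], Optional[str]],
    train: Callable[[str, Optional[str]], List[str]],
    midprocess: Callable[[List[str]], Optional[str]],
    postprocess: Callable[[Optional[str]], Optional[str]],
) -> List[str]:
    """Mimic the step order of build() and return a trace of the steps run."""
    trace: List[str] = []
    prev_output = preprocess()
    trace.append("preprocess")

    for round_idx in range(max_iterations):
        midprocess_input: List[str] = []
        for silo_name in silos:
            # Each silo trains on the previous round's output.
            midprocess_input += train(silo_name, prev_output)
            trace.append(f"train[{round_idx}]:{silo_name}")
        # The combined silo outputs feed a single midprocess step.
        prev_output = midprocess(midprocess_input)
        trace.append(f"midprocess[{round_idx}]")

    postprocess(prev_output)
    trace.append("postprocess")
    return trace


# Hypothetical steps that only pass strings around.
trace = run_rounds(
    silos=["silo_a", "silo_b"],
    max_iterations=2,
    preprocess=lambda: "init",
    train=lambda silo, prev: [f"{silo}<-{prev}"],
    midprocess=lambda outputs: "+".join(outputs),
    postprocess=lambda prev: prev,
)
```

With two silos and two iterations, the trace contains one preprocess entry, four train entries, two midprocess entries, and one postprocess entry.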

create_base_name(self)

The training outputs from each silo will be stored at <base_name>/<silo_name> on the central storage. By default, this method returns a two-word name prefixed with "fl-". Users can override this method.

Returns:

Type Description
str

base_name

Source code in shrike/pipeline/federated_learning.py
@experimental(message=EXPERIMENTAL_WARNING_MSG)
def create_base_name(self) -> str:
    """The training outputs from each silo will be stored at <base_name>/<silo_name> on the central storage.
    By default, it will return a two-word name. Users can override this method.

    Returns:
        base_name
    """
    rv = "fl-"
    rv += generate_slug(2)
    return rv
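As an illustration of overriding this method, the sketch below replaces the random two-word slug with a short random hex suffix. `BaseNameStandIn` is a hypothetical stand-in for `FederatedPipelineBase` so that the example runs without shrike installed; in a real pipeline the override would live on a subclass of `FederatedPipelineBase`.

```python
import uuid


class BaseNameStandIn:
    """Hypothetical stand-in for FederatedPipelineBase (assumption, not shrike code)."""

    def create_base_name(self) -> str:
        return "fl-default-name"


class HexNamedPipeline(BaseNameStandIn):
    def create_base_name(self) -> str:
        # Keep the "fl-" prefix and append 8 random hex characters.
        return f"fl-{uuid.uuid4().hex[:8]}"


base_name = HexNamedPipeline().create_base_name()
# Silo outputs would then be stored under <base_name>/<silo_name>.
silo_path = f"{base_name}/silo_a"
```

Because `build` calls `create_base_name` once per iteration while the pipeline graph is being constructed, an override should probably return a different value on each call. A timestamp with one-second resolution, for example, could repeat across iterations.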

midprocess(self, config, input)

User-defined midprocess step which reads outputs from train in each silo. The outputs will be distributed to each silo's datastore.

Returns:

Type Description
StepOutput

a component/subgraph instance, and a list of output dataset names to be passed to the downstream pipeline in order.

Source code in shrike/pipeline/federated_learning.py
@experimental(message=EXPERIMENTAL_WARNING_MSG)
def midprocess(self, config: DictConfig, input: list) -> StepOutput:
    """
    User-defined midprocess step which reads outputs from `train` in each silo. The outputs will be distributed to each silo's datastore.

    Returns:
        a component/subgraph instance, and a list of output dataset names to be passed to the downstream pipeline **in order**.
    """
    pass
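In `build`, the `input` passed to `midprocess` is built by concatenating each silo's train outputs in silo iteration order. Assuming every silo contributes the same number of outputs (an assumption, since the internals of `_train_in_silo_once` are not shown here), a flat list can be regrouped per silo as sketched below. The helper `group_by_silo` is hypothetical and not part of shrike.

```python
from typing import Dict, List


def group_by_silo(
    flat_inputs: List[str], silo_names: List[str], outputs_per_silo: int
) -> Dict[str, List[str]]:
    """Split a flat, silo-ordered input list into one chunk per silo."""
    expected = len(silo_names) * outputs_per_silo
    if len(flat_inputs) != expected:
        raise ValueError(
            f"expected {expected} inputs, got {len(flat_inputs)}"
        )
    return {
        name: flat_inputs[i * outputs_per_silo:(i + 1) * outputs_per_silo]
        for i, name in enumerate(silo_names)
    }


# Hypothetical dataset names: two outputs from each of two silos.
grouped = group_by_silo(
    ["a_weights", "a_metrics", "b_weights", "b_metrics"],
    ["silo_a", "silo_b"],
    outputs_per_silo=2,
)
```

This is one reason the output names returned by each step are described as ordered: downstream steps can only rely on position to tell the datasets apart.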

pipeline_instance(self, pipeline_function, config)

Creates an instance of the pipeline using arguments. Users do not need to override this method.

Source code in shrike/pipeline/federated_learning.py
def pipeline_instance(self, pipeline_function, config):
    """Creates an instance of the pipeline using arguments. User does not need to modify."""
    pipeline = pipeline_function()
    return pipeline

postprocess(self, config, input)

Optional user-defined postprocess step which reads outputs from train in each silo and writes to the noncompliant_datastore.

Returns:

Type Description
StepOutput

a component/subgraph instance, and a list of output dataset names to be passed to the downstream pipeline in order.

Source code in shrike/pipeline/federated_learning.py
@experimental(message=EXPERIMENTAL_WARNING_MSG)
def postprocess(self, config: DictConfig, input: list) -> StepOutput:
    """
    Optional user-defined postprocess step which reads outputs from `train` in each silo and writes to the `noncompliant_datastore`.

    Returns:
        a component/subgraph instance, and a list of output dataset names to be passed to the downstream pipeline **in order**.
    """
    pass

preprocess(self, config, input=None)

Optional user-defined preprocess step. The outputs will be distributed to each silo's datastore.

Returns:

Type Description
StepOutput

a component/subgraph instance, and a list of output dataset names to be passed to the downstream pipeline in order.

Source code in shrike/pipeline/federated_learning.py
@experimental(message=EXPERIMENTAL_WARNING_MSG)
def preprocess(
    self, config: DictConfig, input: Optional[list] = None
) -> StepOutput:
    """
    Optional user-defined preprocess step. The outputs will be distributed to each silo's datastore.

    Returns:
        a component/subgraph instance, and a list of output dataset names to be passed to the downstream pipeline **in order**.
    """
    pass

train(self, config, input, silo)

User-defined train step happening at each silo. This reads outputs from preprocess or midprocess, and sends outputs back to the noncompliant_datastore.

Returns:

Type Description
StepOutput

a component/subgraph instance, and a list of output dataset names to be passed to the downstream pipeline in order.

Source code in shrike/pipeline/federated_learning.py
@experimental(message=EXPERIMENTAL_WARNING_MSG)
def train(self, config: DictConfig, input: list, silo: DictConfig) -> StepOutput:
    """
    User-defined train step happening at each silo. This reads outputs from `preprocess` or `midprocess`, and sends outputs back to the `noncompliant_datastore`.

    Returns:
        a component/subgraph instance, and a list of output dataset names to be passed to the downstream pipeline **in order**.
    """
    pass

StepOutput dataclass

Output object returned by the preprocess, midprocess, postprocess, or train step of a federated pipeline.
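To show how the four user-defined steps might each return such an object, here is a minimal sketch. The `StepOutput` below is a hypothetical stand-in whose field names (`step`, `outputs`) are assumptions. Check the shrike source for the real definition. `DemoPipeline`, the component names, the output names, and the `compute` key on the silo are likewise invented for illustration. A real pipeline would subclass `FederatedPipelineBase` and return actual component or subgraph instances.

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class StepOutput:
    """Hypothetical stand-in; the field names are assumptions."""

    step: Any = None  # component/subgraph instance
    outputs: List[str] = field(default_factory=list)  # ordered output names


class DemoPipeline:
    """Shows only the shape of the user-defined steps (not shrike code)."""

    def preprocess(self, config: dict, input: Optional[list] = None) -> StepOutput:
        return StepOutput(step="preprocess_component", outputs=["initial_model"])

    def train(self, config: dict, input: list, silo: dict) -> StepOutput:
        # Strings stand in for component instances in this sketch.
        return StepOutput(
            step=f"train_component@{silo['compute']}",
            outputs=["model_weights", "metrics"],
        )

    def midprocess(self, config: dict, input: list) -> StepOutput:
        return StepOutput(step="aggregate_component", outputs=["aggregated_model"])

    def postprocess(self, config: dict, input: list) -> StepOutput:
        return StepOutput(step="publish_component", outputs=["final_model"])


demo = DemoPipeline()
train_result = demo.train({}, ["initial_model"], {"compute": "silo-a-cluster"})
```

The order of the names in `outputs` matters, since the downstream step receives the corresponding datasets in that same order.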