Federated learning
APIs for federated learning.
FederatedPipelineBase
Base class for Federated Learning pipelines.
build(self, config)
Constructs the federated pipeline. User does not need to modify.
Source code in shrike/pipeline/federated_learning.py
@experimental(message=EXPERIMENTAL_WARNING_MSG)
def build(self, config: DictConfig):
    """Construct the federated pipeline. User does not need to modify.

    Args:
        config: pipeline configuration; `config.federated_config` supplies
            the silos and `max_iterations`.

    Returns:
        The `dsl.pipeline`-decorated pipeline function (not yet instantiated;
        see `pipeline_instance`).
    """

    @dsl.pipeline()
    def pipeline_function():
        # NOTE(review): indentation was lost in extraction; the nesting below
        # is reconstructed assuming all steps build the graph inside
        # pipeline_function and postprocess runs once after the loop —
        # confirm against shrike/pipeline/federated_learning.py.
        self._merge_config()
        # Optional orchestrator-side preprocess; a None result means no-op.
        prev = self.preprocess(self.config)
        prev_output = self._process_at_orchestrator(prev) if prev else None
        # `_iteration` instead of `iter`: avoids shadowing the builtin; the
        # index itself is unused — each round trains in every silo, then
        # merges the results via midprocess.
        for _iteration in range(self.config.federated_config.max_iterations):
            cool_name = self.create_base_name()
            midprocess_input = []
            for silo_name, silo in self.config.federated_config.silos.items():
                train_output = self._train_in_silo_once(
                    silo_name, silo, prev_output, cool_name
                )
                midprocess_input += train_output
            prev = self.midprocess(self.config, input=midprocess_input)
            prev_output = self._process_at_orchestrator(prev) if prev else None
        # Optional postprocess after the final iteration.
        prev = self.postprocess(self.config, input=prev_output)
        prev_output = self._process_at_orchestrator(prev) if prev else None

    return pipeline_function
create_base_name(self)

The training outputs from each silo will be stored at `<base_name>/<silo_name>` on the central storage.

Returns:

| Type | Description |
|---|---|
| `str` | base_name |
Source code in shrike/pipeline/federated_learning.py
@experimental(message=EXPERIMENTAL_WARNING_MSG)
def create_base_name(self) -> str:
    """Build the base name under which each silo's training outputs are
    stored (`<base_name>/<silo_name>` on the central storage).

    By default this is `"fl-"` followed by a random two-word slug. Users
    can override this method to customize naming.

    Returns:
        str: the generated base name.
    """
    return "fl-" + generate_slug(2)
midprocess(self, config, input)

User-defined midprocess step which reads outputs from `train` in each silo. The outputs will be distributed to each silo's datastore.

Returns:

| Type | Description |
|---|---|
| `StepOutput` | a component/subgraph instance, and a list of output dataset name to be passed to the downstream pipeline **in order**. |
Source code in shrike/pipeline/federated_learning.py
@experimental(message=EXPERIMENTAL_WARNING_MSG)
def midprocess(self, config: DictConfig, input: list) -> StepOutput:
    """User-defined midprocess step which reads outputs from `train` in each
    silo. The outputs will be distributed to each silo's datastore.

    Args:
        config: pipeline configuration.
        input: training outputs gathered from every silo.

    Returns:
        A component/subgraph instance, and a list of output dataset name to
        be passed to the downstream pipeline **in order**. This default
        implementation does nothing and returns None.
    """
    return None
pipeline_instance(self, pipeline_function, config)
Creates an instance of the pipeline using arguments. User does not need to modify.
Source code in shrike/pipeline/federated_learning.py
def pipeline_instance(self, pipeline_function, config):
    """Create an instance of the pipeline by invoking the pipeline function.

    User does not need to modify.

    Args:
        pipeline_function: the zero-argument function returned by `build`.
        config: pipeline configuration (unused here).

    Returns:
        The instantiated pipeline.
    """
    return pipeline_function()
postprocess(self, config, input)

Optional user-defined postprocess step which reads outputs from `train` in each silo and writes to the `noncompliant_datastore`.

Returns:

| Type | Description |
|---|---|
| `StepOutput` | a component/subgraph instance, and a list of output dataset name to be passed to the downstream pipeline **in order**. |
Source code in shrike/pipeline/federated_learning.py
@experimental(message=EXPERIMENTAL_WARNING_MSG)
def postprocess(self, config: DictConfig, input: list) -> StepOutput:
    """Optional user-defined postprocess step which reads outputs from
    `train` in each silo and writes to the `noncompliant_datastore`.

    Args:
        config: pipeline configuration.
        input: outputs from the preceding orchestrator step.

    Returns:
        A component/subgraph instance, and a list of output dataset name to
        be passed to the downstream pipeline **in order**. This default
        implementation does nothing and returns None.
    """
    return None
preprocess(self, config, input=None)

Optional user-defined preprocess step. The outputs will be distributed to each silo's datastore.

Returns:

| Type | Description |
|---|---|
| `StepOutput` | a component/subgraph instance, and a list of output dataset name to be passed to the downstream pipeline **in order**. |
Source code in shrike/pipeline/federated_learning.py
@experimental(message=EXPERIMENTAL_WARNING_MSG)
def preprocess(
    self, config: DictConfig, input: Optional[list] = None
) -> StepOutput:
    """Optional user-defined preprocess step. The outputs will be
    distributed to each silo's datastore.

    Args:
        config: pipeline configuration.
        input: optional upstream outputs; defaults to None.

    Returns:
        A component/subgraph instance, and a list of output dataset name to
        be passed to the downstream pipeline **in order**. This default
        implementation does nothing and returns None.
    """
    return None
train(self, config, input, silo)

User-defined train step happening at each silo. This reads outputs from `preprocess` or `midprocess`, and sends outputs back to the `noncompliant_datastore`.

Returns:

| Type | Description |
|---|---|
| `StepOutput` | a component/subgraph instance, and a list of output dataset name to be passed to the downstream pipeline **in order**. |
Source code in shrike/pipeline/federated_learning.py
@experimental(message=EXPERIMENTAL_WARNING_MSG)
def train(self, config: DictConfig, input: list, silo: DictConfig) -> StepOutput:
    """User-defined train step happening at each silo. This reads outputs
    from `preprocess` or `midprocess`, and sends outputs back to the
    `noncompliant_datastore`.

    Args:
        config: pipeline configuration.
        input: outputs from the preceding preprocess/midprocess step.
        silo: configuration of the silo this training step runs in.

    Returns:
        A component/subgraph instance, and a list of output dataset name to
        be passed to the downstream pipeline **in order**. This default
        implementation does nothing and returns None.
    """
    return None
StepOutput (dataclass)

Output object from a preprocess/midprocess/postprocess/training step in a federated pipeline.