Skip to content

Spark .NET

Namespace containing Spark .NET utilities.

Inspired by the closed-source implementations:

  • https://dev.azure.com/msdata/Vienna/_git/aml-ds?version=GC94d20cb3f190e942b016c548308becc107fcede8&path=/recipes/signed-components/canary-hdi/run.py
  • https://dev.azure.com/eemo/TEE/_git/TEEGit?version=GC8f000c8c61ae67cf1009d7a70753f0175968ef81&path=/Offline/SGI/src/python/src/sparknet/run.py
  • https://o365exchange.visualstudio.com/O365%20Core/_git/EuclidOdinML?version=GC849128898f37d138725695fbba3992cd5f5f4474&path=/sources/dev/Projects/dotnet/OdinMLDotNet/DotnetSpark-0.4.0.py

full_pyfile_path(spark, py_file_name)

Resolve the full HDFS path of a file from the Spark session's PyFiles (the comma-separated `spark.yarn.dist.pyFiles` configuration setting).

Source code in shrike/spark/spark_net.py
def full_pyfile_path(spark: SparkSession, py_file_name: str) -> str:
    """
    Resolve the full HDFS path of a file out of the Spark Session's PyFiles
    object.

    The comma-separated ``spark.yarn.dist.pyFiles`` setting of the session's
    Spark configuration is searched for an entry whose last ``/``-separated
    component equals ``py_file_name``; the first such entry is returned
    unchanged.

    Args:
        spark: Spark session whose configuration is inspected.
        py_file_name: Bare file name (no directory part) to look for.

    Returns:
        The full ``spark.yarn.dist.pyFiles`` entry matching ``py_file_name``.

    Raises:
        PublicValueError: if ``spark.yarn.dist.pyFiles`` is not set, or if no
            entry in it matches ``py_file_name``.
    """
    spark_conf = spark.sparkContext.getConf()
    py_files = spark_conf.get("spark.yarn.dist.pyFiles")

    # SparkConf.get returns None for an absent key; report that clearly
    # instead of failing with an AttributeError on None.split(...).
    if py_files is None:
        raise PublicValueError(
            "Spark configuration 'spark.yarn.dist.pyFiles' is not set; "
            f"cannot resolve {py_file_name}"
        )

    file_names = py_files.split(",")

    log.info(
        f"Searching py_files for {py_file_name}: '{py_files}'",
        category=DataCategory.PUBLIC,
    )

    for file_name in file_names:
        # Only the final path component is compared, so entries may carry
        # any directory prefix or URI scheme.
        if file_name.rsplit("/", 1)[-1] == py_file_name:
            return file_name

    raise PublicValueError(f"py_files do not contain {py_file_name}: {py_files}")

get_default_spark_session()

Resolve a default Spark session for running Spark .NET applications.

Source code in shrike/spark/spark_net.py
def get_default_spark_session() -> SparkSession:
    """
    Build (or reuse, via ``getOrCreate``) a Spark session for running
    Spark .NET applications.

    The requested application name is ``spark-net-`` followed by the first
    eight hex characters of a freshly generated random UUID.

    Returns:
        The session produced by ``SparkSession.builder.getOrCreate()``.
    """
    # Random short suffix from a UUID, see https://stackoverflow.com/a/534847
    suffix = uuid.uuid4().hex[:8]
    app_name = "spark-net-" + suffix
    log.info(f"Application name: {app_name}", category=DataCategory.PUBLIC)

    builder = SparkSession.builder.appName(app_name)
    return builder.getOrCreate()

java_args(spark, args)

Convert a Python list into the corresponding Java argument array.

Source code in shrike/spark/spark_net.py
def java_args(spark: SparkSession, args: List[str]):
    """
    Turn a Python list of strings into a Java ``String[]`` array.

    Args:
        spark: Session whose JVM view supplies the ``java.lang.String`` class.
        args: Strings to copy into the array, in order.

    Returns:
        A Py4J-backed Java array of length ``len(args)`` holding ``args``.
    """
    gateway = SparkContext._gateway
    string_type = spark._jvm.java.lang.String
    java_array = gateway.new_array(string_type, len(args))

    # Fill the array slot by slot using index-aware iteration,
    # see https://stackoverflow.com/a/522578
    for position, value in enumerate(args):
        java_array[position] = value

    return java_array

run_spark_net(zip_file='--zipFile', binary_name='--binaryName', spark=None, args=None)

Easy entry point to run a Spark .NET application in a single line. The simplest sample usage is:

run_spark_net()

Source code in shrike/spark/spark_net.py
def run_spark_net(
    zip_file: str = "--zipFile",
    binary_name: str = "--binaryName",
    spark: Optional[SparkSession] = None,
    args: Optional[list] = None,
) -> None:
    """
    Easy entry point to run a Spark .NET application in a single line.
    Simplest sample usage is:

    > run_spark_net()

    The command line is parsed for two options (``--zipFile`` and
    ``--binaryName`` by default); every other token is forwarded unchanged to
    the .NET application through ``run_spark_net_from_known_assembly``.

    Args:
        zip_file: Option name carrying the zip file name, which is looked up
            in the session's PyFiles.
        binary_name: Option name carrying the assembly to run inside the zip.
        spark: Session to use; a default one is created when omitted.
        args: Command line to parse; ``sys.argv`` is used when this is
            omitted or empty.

    Raises:
        PublicArgumentError: if the command line cannot be parsed, or if
            either the zip-file or binary-name option is missing.
    """
    # Set up compliant logging first: get_default_spark_session() already
    # logs with a `category=` keyword.
    enable_compliant_logging()

    if not spark:
        spark = get_default_spark_session()

    if not args:
        # NOTE(review): this is all of sys.argv, including the script path in
        # argv[0]. argparse leaves that token in `unknown_args`, so it is
        # forwarded to the .NET application. Kept for compatibility; confirm
        # whether sys.argv[1:] was intended.
        args = sys.argv

    parser = argparse.ArgumentParser()
    parser.add_argument(zip_file, dest="ZIP_FILE")
    parser.add_argument(binary_name, dest="BINARY_NAME")

    try:
        known_args, unknown_args = parser.parse_known_args(args)
    except (argparse.ArgumentError, SystemExit) as e:
        # argparse reports parse errors by exiting; convert that into a public
        # error without also swallowing KeyboardInterrupt and friends.
        raise PublicArgumentError(None, str(e)) from e

    zf = known_args.ZIP_FILE
    bn = known_args.BINARY_NAME

    # Both options are optional as far as argparse is concerned; fail fast
    # here instead of letting a None leak into the path lookup.
    missing = [
        option
        for option, value in ((zip_file, zf), (binary_name, bn))
        if value is None
    ]
    if missing:
        raise PublicArgumentError(
            None, f"Missing required argument(s): {', '.join(missing)}"
        )

    run_spark_net_from_known_assembly(spark, zf, bn, unknown_args)

run_spark_net_from_known_assembly(spark, zip_file_name, assembly_name, args)

Invoke the binary assembly_name inside zip_file_name with the command line parameters args, using the provided Spark session. If the job fails, log the Java stack trace and raise an error.

Source code in shrike/spark/spark_net.py
def run_spark_net_from_known_assembly(
    spark: SparkSession, zip_file_name: str, assembly_name: str, args: List[str]
) -> None:
    """
    Run the .NET assembly ``assembly_name`` packaged in ``zip_file_name``
    through Spark's ``DotnetRunner``, passing ``args`` as its command line.

    Args:
        spark: Session whose JVM hosts ``DotnetRunner``.
        zip_file_name: Bare zip file name, resolved to a full path with
            ``full_pyfile_path``.
        assembly_name: Name of the binary inside the zip to invoke.
        args: Extra command-line arguments placed after the zip path and the
            assembly name.

    Raises:
        PublicRuntimeError: if ``DotnetRunner.main`` throws a Java exception;
            the Java stack frames are logged before raising.
    """
    resolved_zip = full_pyfile_path(spark, zip_file_name)
    dotnet_args = [resolved_zip, assembly_name] + args
    log.info(
        f"Calling dotnet with arguments: {dotnet_args}", category=DataCategory.PUBLIC
    )

    jvm_args = java_args(spark, dotnet_args)
    failure = None

    try:
        spark._jvm.org.apache.spark.deploy.dotnet.DotnetRunner.main(jvm_args)
    except py4j.protocol.Py4JJavaError as err:
        log.error("Dotnet failed", category=DataCategory.PUBLIC)
        java_exc = err.java_exception
        for frame in java_exc.getStackTrace():
            log.error(str(frame), category=DataCategory.PUBLIC)
        failure = f"{err.errmsg} {java_exc}"

    # Raise outside the except clause so the Py4JJavaError is not attached as
    # context: it is unprintable (https://github.com/bartdag/py4j/issues/306).
    if failure:
        raise PublicRuntimeError(failure)

    log.info("Done running dotnet", category=DataCategory.PUBLIC)