Home Reference Source Test

packages/causality-layer/src/CausalNetRunner/runner.mixins.js

/**
 * Mixin factory that provides network-running methods for a runner class.
 *
 * The produced class reads two members it does not define itself:
 * `this.R` (used through `addIndex` and `reduce`, i.e. a Ramda-compatible
 * helper) and `this.logger` (only `debug` is called). Both are expected to
 * come from `BaseRunnerClass` or from the instance.
 *
 * @class RunnerMixins
 * @extends BaseRunnerClass
 * @param {Function} BaseRunnerClass - class to extend.
 * @returns {Function} a class extending `BaseRunnerClass` with the methods below.
 */
const RunnerMixins = ( BaseRunnerClass )=> class extends BaseRunnerClass{
    /**
     * Parameter sets of the net. The getters below read the keys
     * `PredictParameters`, `EncodeParameters` and `DecodeParameters`.
     * @type {Object}
     */
    set NetParameters(parameters){
        this.netParameters = parameters;
    }
    /**
     * @throws {Error} when no parameters have been assigned yet.
     */
    get NetParameters(){
        if(!this.netParameters){
            throw new Error('netParameters is not set');
        }
        return this.netParameters;
    }
    /**
     * Layer lists of the net. The getters below read the keys
     * `Predict`, `Encode` and `Decode`.
     * @type {Object}
     */
    set NetLayers(netLayers){
        this.netLayers = netLayers;
    }
    /**
     * @throws {Error} when no layers have been assigned yet.
     */
    get NetLayers(){
        if(!this.netLayers){
            throw new Error('netLayers is not set');
        }
        return this.netLayers;
    }

    /**
     * Apply a flow of operations to `value`, one node after the other.
     *
     * Each node names a method (`Op`) that is invoked on the current result.
     * When the node has a truthy `Parameter`, `parameters[node.Parameter]` is
     * passed as the first argument, followed by `node.Args`; otherwise only
     * `node.Args` are passed. A node without `Args` is called with no extra
     * arguments.
     *
     * @param {Object} value - initial value; must expose the methods named by the flow.
     * @param {Array<{Op: string, Parameter: (string|undefined), Args: (Array|undefined)}>} flow - ordered operations.
     * @param {Object} parameters - lookup table for the nodes' `Parameter` keys.
     * @returns {{result: Object, trace: Object}} final value and, per node index,
     *     the `shape` property of the value produced by that node.
     */
    runOpFlow(value, flow, parameters){
        const R = this.R;
        const reduceIndexed = R.addIndex(R.reduce);
        const applyNode = ({result, trace}, node, idx)=>{
            // A missing Args list means "no extra arguments" instead of a spread error.
            const args = node.Args || [];
            const next = node.Parameter
                ? result[node.Op](parameters[node.Parameter], ...args)
                : result[node.Op](...args);
            trace[idx] = next.shape;
            return {result: next, trace};
        };
        const {result, trace} = reduceIndexed(applyNode, {result: value, trace: {}}, flow);
        return {result, trace};
    }

    /**
     * Run a layer that is given as a function.
     *
     * @param {Object} value - input of the layer.
     * @param {Function} net - called as `net(value, parameters)`; must return
     *     an object with `result` and `trace`.
     * @param {Object} parameters - parameters handed to `net`.
     * @returns {{result: Object, trace: Object}} only these two keys of what `net` returned.
     */
    runOpLayer(value, net, parameters){
        const {result, trace} = net(value, parameters);
        return {result, trace};
    }

    /**
     * Run one configured layer, dispatching on its `Type`:
     * `'Tensor'` runs `Flow` through {@link runOpFlow}, `'Layer'` runs `Net`
     * through {@link runOpLayer}.
     *
     * @param {Object} value - input of the layer.
     * @param {{Name: string, Type: string, Flow: (Array|undefined), Net: (Function|undefined)}} layerConfigure
     * @param {Object} layerParameters - parameters of this layer.
     * @returns {Object} `{[Name]: result, trace}`.
     * @throws {Error} when `Type` is neither `'Tensor'` nor `'Layer'`.
     */
    runLayer(value, layerConfigure, layerParameters){
        const {Name, Type, Flow, Net} = layerConfigure;
        let output;
        if(Type === 'Tensor'){
            output = this.runOpFlow(value, Flow, layerParameters);
        }
        else if(Type === 'Layer'){
            output = this.runOpLayer(value, Net, layerParameters);
        }
        else{
            throw new Error('type must be either Layer or Tensor');
        }
        return {[Name]: output.result, trace: output.trace};
    }

    /**
     * Record the trace of a layer when trace collection is enabled.
     *
     * @param {(Array|null)} traces - collector; nothing happens when falsy.
     * @param {string} name - layer name used as key of the pushed entry.
     * @param {Object} trace - trace of that layer.
     */
    tracing(traces, name, trace){
        if(traces){
            traces.push({[name]: trace});
        }
    }

    /**
     * Pipe `samples` through `layers` in order; every layer receives the
     * output of the previous one and the parameters stored under its own name.
     *
     * @param {Iterable<Object>} layers - layer configurations, see {@link runLayer}.
     * @param {Object} samples - input of the first layer.
     * @param {Object} parameters - per-layer parameters keyed by layer `Name`.
     * @param {(Array|null)} [traces=null] - when given, receives one
     *     `{[layerName]: trace}` entry per layer and is logged via `this.logger.debug`.
     * @returns {Object} output of the last layer, or `samples` when `layers` is empty.
     */
    run(layers, samples, parameters, traces=null){
        const pipeValue = {PipeInput: samples};
        let lastLayer = 'PipeInput';
        for(const layer of layers){
            const layerOutput = this.runLayer(pipeValue[lastLayer], layer, parameters[layer.Name]);
            pipeValue[layer.Name] = layerOutput[layer.Name];
            lastLayer = layer.Name;
            this.tracing(traces, layer.Name, layerOutput.trace);
        }
        if(traces){
            this.logger.debug({traces});
        }
        return pipeValue[lastLayer];
    }

    /**
     * Function running the `Predict` layers with `PredictParameters`.
     * The layers are captured when the getter is read; the parameters are
     * looked up again on every call.
     * @type {function(Object): Object}
     */
    get Predictor(){
        const predictLayers = this.NetLayers.Predict;
        const PredictParametersLenses = ()=>this.NetParameters.PredictParameters;
        return (samples)=>this.run(predictLayers, samples, PredictParametersLenses());
    }

    /**
     * Function running the `Encode` layers with `EncodeParameters`.
     * The layers are captured when the getter is read; the parameters are
     * looked up again on every call.
     * @type {function(Object): Object}
     */
    get Encoder(){
        const encodeLayers = this.NetLayers.Encode;
        const EncodeParametersLenses = ()=>this.NetParameters.EncodeParameters;
        return (samples)=>this.run(encodeLayers, samples, EncodeParametersLenses());
    }

    /**
     * Function running the `Decode` layers with `DecodeParameters`.
     * The layers are captured when the getter is read; the parameters are
     * looked up again on every call.
     * @type {function(Object): Object}
     */
    get Decoder(){
        const decodeLayers = this.NetLayers.Decode;
        const DecodeParametersLenses = ()=>{
            const {DecodeParameters, EncodeParameters} = this.NetParameters;
            // Earlier versions always read EncodeParameters here; keep that as
            // a fallback so parameter sets without DecodeParameters still run.
            return DecodeParameters == null ? EncodeParameters : DecodeParameters;
        };
        return (samples)=>this.run(decodeLayers, samples, DecodeParametersLenses());
    }
};

export default RunnerMixins;