package org.elasticsearch.xpack.ml.dataframe.steps;

import java.util.Objects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
import org.elasticsearch.xpack.ml.dataframe.DestinationIndex;
import org.elasticsearch.xpack.ml.dataframe.inference.InferenceRunner;
import org.elasticsearch.xpack.ml.dataframe.steps.DataFrameAnalyticsStep;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.utils.MlIndicesUtils;

/* loaded from: input_file:org/elasticsearch/xpack/ml/dataframe/steps/InferenceStep.class */
/**
 * Data frame analytics step that runs inference over the destination index.
 *
 * <p>The step is skipped when the analysis does not support inference, or when the destination
 * index has no documents whose {@code <results_field>.is_training} flag is {@code true}-excluded
 * (i.e. no test documents). Otherwise it looks up the most recently created trained model tagged
 * with this analytics config's id and runs the {@link InferenceRunner} with it on the ML utility
 * thread pool.
 */
public class InferenceStep extends AbstractDataFrameAnalyticsStep {

    private static final Logger LOGGER = LogManager.getLogger(InferenceStep.class);

    /** Index pattern searched when resolving the model id to use for inference. */
    private static final String INFERENCE_INDEX_PATTERN = ".ml-inference-*";

    private final ThreadPool threadPool;
    private final InferenceRunner inferenceRunner;

    /**
     * @param nodeClient               client used for the refresh and search requests
     * @param dataFrameAnalyticsTask   the task this step belongs to
     * @param dataFrameAnalyticsAuditor auditor passed through to the base step
     * @param dataFrameAnalyticsConfig the analytics job configuration
     * @param threadPool               thread pool whose utility executor runs the (blocking) inference
     * @param inferenceRunner          runner that performs inference for a given model id
     * @throws NullPointerException if {@code threadPool} or {@code inferenceRunner} is null
     */
    public InferenceStep(NodeClient nodeClient, DataFrameAnalyticsTask dataFrameAnalyticsTask,
                         DataFrameAnalyticsAuditor dataFrameAnalyticsAuditor,
                         DataFrameAnalyticsConfig dataFrameAnalyticsConfig, ThreadPool threadPool,
                         InferenceRunner inferenceRunner) {
        super(nodeClient, dataFrameAnalyticsTask, dataFrameAnalyticsAuditor, dataFrameAnalyticsConfig);
        this.threadPool = Objects.requireNonNull(threadPool);
        this.inferenceRunner = Objects.requireNonNull(inferenceRunner);
    }

    @Override
    public DataFrameAnalyticsStep.Name name() {
        return DataFrameAnalyticsStep.Name.INFERENCE;
    }

    /**
     * Runs the step: refresh the destination index, check for test docs, resolve the model id,
     * then run inference. Every stage forwards failures to {@code listener}.
     */
    @Override
    protected void doExecute(ActionListener<StepResponse> listener) {
        if (config.getAnalysis().supportsInference() == false) {
            LOGGER.debug(() -> new ParameterizedMessage(
                "[{}] Inference step completed immediately as analysis does not support inference", config.getId()));
            listener.onResponse(new StepResponse(false));
            return;
        }

        ActionListener<String> modelIdListener = ActionListener.wrap(
            modelId -> runInference(modelId, listener),
            listener::onFailure
        );

        ActionListener<Boolean> testDocsExistListener = ActionListener.wrap(
            testDocsExist -> {
                if (testDocsExist) {
                    getModelId(modelIdListener);
                    return;
                }
                // Nothing to infer on, so skip resolving and loading a model entirely.
                LOGGER.debug(() -> new ParameterizedMessage(
                    "[{}] Inference step completed immediately as there are no test docs", config.getId()));
                task.getStatsHolder().getProgressTracker().updateInferenceProgress(100);
                listener.onResponse(new StepResponse(isTaskStopping()));
            },
            listener::onFailure
        );

        // Refresh first so the test-docs search sees all documents written by earlier steps.
        refreshDestAsync(ActionListener.wrap(
            refreshResponse -> searchIfTestDocsExist(testDocsExistListener),
            listener::onFailure
        ));
    }

    /**
     * Runs inference for {@code modelId} on the ML utility executor and notifies {@code listener}
     * exactly once. A failure that happens while the task is stopping is reported as a
     * (non-stopping) response rather than as a failure.
     */
    private void runInference(String modelId, ActionListener<StepResponse> listener) {
        threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
            try {
                inferenceRunner.run(modelId);
            } catch (Exception e) {
                if (task.isStopping()) {
                    listener.onResponse(new StepResponse(false));
                } else {
                    listener.onFailure(e);
                }
                return;
            }
            // Notify outside the try so an exception thrown by the listener itself cannot
            // trigger a second notification from the catch block.
            listener.onResponse(new StepResponse(isTaskStopping()));
        });
    }

    /**
     * Responds with {@code true} if the destination index contains at least one document not
     * flagged as training data. A missing destination index is ignored rather than failing.
     */
    private void searchIfTestDocsExist(ActionListener<Boolean> listener) {
        SearchRequest searchRequest = new SearchRequest(config.getDest().getIndex());
        searchRequest.indicesOptions(MlIndicesUtils.addIgnoreUnavailable(SearchRequest.DEFAULT_INDICES_OPTIONS));
        searchRequest.source().query(
            QueryBuilders.boolQuery().mustNot(
                QueryBuilders.termQuery(config.getDest().getResultsField() + "." + DestinationIndex.IS_TRAINING, true)));
        // Only existence matters: no hits returned, stop counting after the first match.
        searchRequest.source().size(0);
        searchRequest.source().trackTotalHitsUpTo(1);

        ClientHelper.executeAsyncWithOrigin(client, ClientHelper.ML_ORIGIN, SearchAction.INSTANCE, searchRequest,
            ActionListener.wrap(
                searchResponse -> listener.onResponse(searchResponse.getHits().getTotalHits().value > 0),
                listener::onFailure
            ));
    }

    /**
     * Resolves the id of the newest (by create time) trained model tagged with this config's id.
     * Fails with {@link ResourceNotFoundException} when no such model exists.
     */
    private void getModelId(ActionListener<String> listener) {
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        searchSourceBuilder.size(1);
        searchSourceBuilder.fetchSource(false);
        searchSourceBuilder.query(QueryBuilders.boolQuery()
            .filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), config.getId())));
        searchSourceBuilder.sort(TrainedModelConfig.CREATE_TIME.getPreferredName(), SortOrder.DESC);
        SearchRequest searchRequest = new SearchRequest(INFERENCE_INDEX_PATTERN);
        searchRequest.source(searchSourceBuilder);

        ClientHelper.executeAsyncWithOrigin(client, ClientHelper.ML_ORIGIN, SearchAction.INSTANCE, searchRequest,
            ActionListener.wrap(
                searchResponse -> {
                    SearchHit[] hits = searchResponse.getHits().getHits();
                    if (hits.length == 0) {
                        listener.onFailure(new ResourceNotFoundException("No model could be found to perform inference"));
                    } else {
                        listener.onResponse(hits[0].getId());
                    }
                },
                listener::onFailure
            ));
    }

    /** Cancels any running inference; {@code reason} and {@code timeout} are not used. */
    @Override
    public void cancel(String reason, TimeValue timeout) {
        inferenceRunner.cancel();
    }

    /** This step has no extra progress to refresh, so the listener completes immediately. */
    @Override
    public void updateProgress(ActionListener<Void> listener) {
        listener.onResponse(null);
    }
}
