package org.elasticsearch.xpack.ml.dataframe.extractor;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.util.CachedSupplier;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.DestinationIndex;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitter;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.extractor.ProcessedField;

/* loaded from: input_file:org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.class */
public class DataFrameDataExtractor {
    private static final Logger LOGGER = LogManager.getLogger(DataFrameDataExtractor.class);

    /**
     * Marker stored in an extracted row in place of a missing value; it is written by
     * {@code extractNonProcessedValues} and {@code extractProcessedValue} when the context
     * supports rows with missing values.
     *
     * <p>The value is a single NUL character. The decompiled literal showed two U+FFFD
     * replacement characters, which is what the class-file (modified UTF-8) encoding of NUL,
     * bytes {@code 0xC0 0x80}, becomes when it is decoded as standard UTF-8.
     */
    public static final String NULL_VALUE = "\0";

    private final Client client;
    private final DataFrameDataExtractorContext context;

    /** Set by {@link #cancel()}; checked while search hits are converted into rows. */
    private boolean isCancelled;

    /** Cleared once a search yields no rows or extraction is cancelled part-way through a batch. */
    private boolean hasNext;

    /** Remembers that one search already failed, so a second consecutive failure is rethrown. */
    private boolean hasPreviousSearchFailed;

    private final CachedSupplier<TrainTestSplitter> trainTestSplitter;
    private final String[] organicFeatures;
    private final String[] processedFeatures;

    /** Sort key of the last row handed out by {@link #next()}; {@code -1} before the first batch. */
    private long lastSortKey = -1;

    /** Extracted fields keyed by name, in the order the context lists them. */
    private final Map<String, ExtractedField> extractedFieldsByName = new LinkedHashMap<>();

    /* loaded from: input_file:org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor$DataSummary.class */
    /** Immutable holder for the dimensions of the extracted data: number of rows and columns. */
    public static class DataSummary {
        /** Number of rows. */
        public final long rows;
        /** Number of columns. */
        public final int cols;

        /**
         * @param rows number of rows
         * @param cols number of columns
         */
        public DataSummary(long rows, int cols) {
            this.cols = cols;
            this.rows = rows;
        }
    }

    /* loaded from: input_file:org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor$Row.class */
    /**
     * A search hit paired with the feature values extracted from it.
     * A row whose values are {@code null} is one that should be skipped.
     */
    public static class Row {
        private final SearchHit hit;

        @Nullable
        private final String[] values;
        private final boolean isTraining;

        private Row(String[] values, SearchHit hit, boolean isTraining) {
            this.hit = hit;
            this.values = values;
            this.isTraining = isTraining;
        }

        /** @return the extracted values, or {@code null} when the row should be skipped */
        @Nullable
        public String[] getValues() {
            return values;
        }

        /** @return the search hit this row was built from */
        public SearchHit getHit() {
            return hit;
        }

        /** @return {@code true} when no values were extracted for this row */
        public boolean shouldSkip() {
            return values == null;
        }

        /** @return the training flag this row was created with */
        public boolean isTraining() {
            return isTraining;
        }

        /** @return the sort key narrowed to an {@code int} (high bits are discarded) */
        public int getChecksum() {
            long sortKey = getSortKey();
            return (int) sortKey;
        }

        /** @return the first sort value of the hit, which must be a {@link Long} */
        public long getSortKey() {
            Object firstSortValue = hit.getSortValues()[0];
            return (Long) firstSortValue;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Creates an extractor over the given client and context.
     * Feature names and the name-to-field lookup are taken from the context's extracted fields;
     * the train/test splitter is created lazily through a {@link CachedSupplier}.
     *
     * @param client the client used to run searches; must not be {@code null}
     * @param dataFrameDataExtractorContext the extraction context; must not be {@code null}
     */
    public DataFrameDataExtractor(Client client, DataFrameDataExtractorContext dataFrameDataExtractorContext) {
        this.client = Objects.requireNonNull(client);
        this.context = Objects.requireNonNull(dataFrameDataExtractorContext);

        ExtractedFields fields = context.extractedFields;
        this.organicFeatures = fields.extractOrganicFeatureNames();
        this.processedFeatures = fields.extractProcessedFeatureNames();
        for (ExtractedField field : fields.getAllFields()) {
            extractedFieldsByName.put(field.getName(), field);
        }

        this.hasNext = true;
        this.hasPreviousSearchFailed = false;
        // A bound method reference null-checks its receiver at creation time.
        this.trainTestSplitter = new CachedSupplier<>(context.trainTestSplitterFactory::create);
    }

    /** @return a read-only view of the headers held by the context */
    public Map<String, String> getHeaders() {
        Map<String, String> headers = context.headers;
        return Collections.unmodifiableMap(headers);
    }

    /** @return {@code false} once a search produced no rows or extraction was cancelled mid-batch */
    public boolean hasNext() {
        return hasNext;
    }

    /** @return {@code true} after {@link #cancel()} has been called */
    public boolean isCancelled() {
        return isCancelled;
    }

    /** Logs the cancellation at debug level and raises the cancelled flag. */
    public void cancel() {
        LOGGER.debug(() -> new ParameterizedMessage("[{}] Data extractor was cancelled", context.jobId));
        isCancelled = true;
    }

    /**
     * Fetches the next batch of rows.
     * When the batch is absent or empty the extractor is marked as exhausted; otherwise the sort
     * key of the last row is remembered so the following search continues after it.
     *
     * @return the batch, or an empty {@link Optional} when the search returned no hits
     * @throws NoSuchElementException if {@link #hasNext()} is {@code false}
     * @throws IOException declared for the underlying search
     */
    public Optional<List<Row>> next() throws IOException {
        if (hasNext() == false) {
            throw new NoSuchElementException();
        }
        List<Row> batch = nextSearch();
        if (batch != null && batch.isEmpty() == false) {
            Row lastRow = batch.get(batch.size() - 1);
            lastSortKey = lastRow.getSortKey();
        } else {
            hasNext = false;
        }
        return Optional.ofNullable(batch);
    }

    /**
     * Asynchronously runs a single search of up to {@code context.scrollSize} documents matching
     * the context query and hands the resulting rows to the listener.
     *
     * <p>Hits for which no values could be extracted become rows with {@code null} values and the
     * training flag set; all other rows carry their values with the training flag cleared.
     * Search failures are forwarded to {@code actionListener.onFailure}.
     *
     * @param actionListener receives the preview rows (an empty list when there are no hits);
     *                       must not be {@code null}
     */
    public void preview(ActionListener<List<Row>> actionListener) {
        // Fail fast, before any request is assembled.
        Objects.requireNonNull(actionListener);

        SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE)
            .setAllowPartialSearchResults(false)
            .setIndices(context.indices)
            .setSize(context.scrollSize)
            .setQuery(QueryBuilders.boolQuery().filter(context.query));
        setFetchSource(searchRequestBuilder);
        for (ExtractedField docValueField : context.extractedFields.getDocValueFields()) {
            searchRequestBuilder.addDocValueField(docValueField.getSearchField(), docValueField.getDocValueFormat());
        }
        searchRequestBuilder.setRuntimeMappings(context.runtimeMappings);

        // Typing the listener explicitly makes the lambda parameter a SearchResponse; with a raw
        // CheckedConsumer it would be an Object and getHits() would not resolve.
        ActionListener<SearchResponse> searchListener = ActionListener.wrap(searchResponse -> {
            SearchHit[] hits = searchResponse.getHits().getHits();
            if (hits.length == 0) {
                actionListener.onResponse(Collections.emptyList());
                return;
            }
            List<Row> rows = new ArrayList<>(hits.length);
            for (SearchHit hit : hits) {
                String[] values = extractValues(hit);
                rows.add(values == null ? new Row(null, hit, true) : new Row(values, hit, false));
            }
            actionListener.onResponse(rows);
        }, actionListener::onFailure);

        ClientHelper.executeWithHeadersAsync(
            context.headers,
            "ml",
            client,
            SearchAction.INSTANCE,
            searchRequestBuilder.request(),
            searchListener
        );
    }

    /**
     * Builds and executes the next range search and converts its hits into rows.
     *
     * @return the rows of the batch, or {@code null} when the search returned no hits
     * @throws IOException declared for the underlying search
     */
    protected List<Row> nextSearch() throws IOException {
        Supplier<SearchResponse> search = () -> executeSearchRequest(buildSearchRequest());
        return tryRequestWithSearchResponse(search);
    }

    /**
     * Runs the supplied search and processes its response, retrying once on failure.
     * The first failure is logged, recorded, and followed by a fresh {@link #nextSearch()};
     * a failure while the previous one is still recorded is rethrown to the caller.
     *
     * @param supplier produces the search response
     * @return the rows of the batch, or {@code null} when the search returned no hits
     * @throws IOException declared for the retried search
     */
    private List<Row> tryRequestWithSearchResponse(Supplier<SearchResponse> supplier) throws IOException {
        try {
            SearchResponse response = supplier.get();
            LOGGER.trace(() -> new ParameterizedMessage("[{}] Search response was obtained", context.jobId));
            List<Row> rows = processSearchResponse(response);
            // A successful round clears the failure marker.
            hasPreviousSearchFailed = false;
            return rows;
        } catch (Exception e) {
            if (hasPreviousSearchFailed == false) {
                LOGGER.warn(new ParameterizedMessage("[{}] Search resulted to failure; retrying once", context.jobId), e);
                markScrollAsErrored();
                return nextSearch();
            }
            throw e;
        }
    }

    /**
     * Executes the search synchronously with the context headers under the {@code "ml"} origin.
     *
     * @param searchRequestBuilder the prepared search; must not be {@code null}
     * @return the search response
     */
    protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) {
        // The bound method reference throws NullPointerException for a null builder.
        return ClientHelper.executeWithHeaders(context.headers, "ml", client, searchRequestBuilder::get);
    }

    /**
     * Builds the search for the next batch: documents matching the context query whose
     * incremental id lies in {@code [lastSortKey + 1, lastSortKey + 1 + scrollSize)}, sorted
     * ascending by that id.
     *
     * @return the prepared search request builder
     */
    private SearchRequestBuilder buildSearchRequest() {
        final long from = lastSortKey + 1;
        final long to = from + context.scrollSize;
        LOGGER.trace(
            () -> new ParameterizedMessage(
                "[{}] Searching docs with [{}] in [{}, {})",
                context.jobId,
                DestinationIndex.INCREMENTAL_ID,
                from,
                to
            )
        );

        SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE)
            .setAllowPartialSearchResults(false)
            .addSort(DestinationIndex.INCREMENTAL_ID, SortOrder.ASC)
            .setIndices(context.indices)
            .setSize(context.scrollSize);

        QueryBuilder idRange = QueryBuilders.rangeQuery(DestinationIndex.INCREMENTAL_ID).gte(from).lt(to);
        searchRequestBuilder.setQuery(QueryBuilders.boolQuery().filter(context.query).filter(idRange));

        setFetchSource(searchRequestBuilder);
        for (ExtractedField docValueField : context.extractedFields.getDocValueFields()) {
            searchRequestBuilder.addDocValueField(docValueField.getSearchField(), docValueField.getDocValueFormat());
        }
        searchRequestBuilder.setRuntimeMappings(context.runtimeMappings);
        return searchRequestBuilder;
    }

    /**
     * Configures source fetching on the request: the whole source when the context asks for it,
     * otherwise only the source fields the extracted fields need, or no source and no stored
     * fields when none are needed.
     *
     * @param searchRequestBuilder the request to configure
     */
    private void setFetchSource(SearchRequestBuilder searchRequestBuilder) {
        if (context.includeSource) {
            searchRequestBuilder.setFetchSource(true);
        } else {
            String[] sourceFields = context.extractedFields.getSourceFields();
            if (sourceFields.length == 0) {
                searchRequestBuilder.setFetchSource(false);
                searchRequestBuilder.storedFields("_none_");
            } else {
                // The cast selects the (String[] includes, String[] excludes) overload.
                searchRequestBuilder.setFetchSource(sourceFields, (String[]) null);
            }
        }
    }

    /**
     * Converts the hits of a search response into rows.
     * Conversion stops early, returning the rows built so far, if the extractor is cancelled.
     *
     * @param searchResponse the response to process
     * @return the rows, or {@code null} when the response has no hits
     */
    private List<Row> processSearchResponse(SearchResponse searchResponse) {
        SearchHit[] hits = searchResponse.getHits().getHits();
        if (hits.length == 0) {
            hasNext = false;
            return null;
        }
        List<Row> rows = new ArrayList<>(hits.length);
        for (SearchHit hit : hits) {
            if (isCancelled) {
                hasNext = false;
                break;
            }
            rows.add(createRow(hit));
        }
        return rows;
    }

    /**
     * Extracts the single value of the named field from the hit.
     *
     * @param searchHit the hit to read from
     * @param str the name of the extracted field
     * @return the value as a string; {@link #NULL_VALUE} when the field has no value and rows with
     *         missing values are supported; {@code null} in every other case (multiple values,
     *         an unsupported value type, or a missing value that is not tolerated)
     */
    private String extractNonProcessedValues(SearchHit searchHit, String str) {
        Object[] fieldValues = extractedFieldsByName.get(str).value(searchHit);
        if (fieldValues.length == 0) {
            return context.supportsRowsWithMissingValues ? NULL_VALUE : null;
        }
        if (fieldValues.length == 1 && isValidValue(fieldValues[0])) {
            return Objects.toString(fieldValues[0]);
        }
        return null;
    }

    /**
     * Extracts the output values of a processed field from the hit.
     *
     * @param processedField the processed field to evaluate
     * @param searchHit the hit to read from
     * @return one string per output field name, with {@link #NULL_VALUE} for outputs that are
     *         missing while rows with missing values are supported; {@code null} when the row
     *         cannot be used (no output and missing values are not supported, or an output of an
     *         unsupported type)
     * @throws RuntimeException a bad-request exception from {@link ExceptionsHelper} when the
     *         processor yields a non-empty output whose size differs from its output field names
     */
    private String[] extractProcessedValue(ProcessedField processedField, SearchHit searchHit) {
        // The processor resolves its input fields by name through the lookup map.
        Object[] outputs = processedField.value(searchHit, extractedFieldsByName::get);
        if (outputs.length == 0 && context.supportsRowsWithMissingValues == false) {
            return null;
        }

        final int expectedSize = processedField.getOutputFieldNames().size();
        String[] extracted = new String[expectedSize];
        Arrays.fill(extracted, NULL_VALUE);

        // An empty output at this point implies missing values are supported (checked above).
        if (outputs.length == 0) {
            return extracted;
        }
        if (outputs.length != expectedSize) {
            throw ExceptionsHelper.badRequestException(
                "field_processor [{}] output size expected to be [{}], instead it was [{}]",
                processedField.getProcessorName(),
                expectedSize,
                outputs.length
            );
        }

        for (int i = 0; i < expectedSize; i++) {
            Object output = outputs[i];
            if (output == null && context.supportsRowsWithMissingValues) {
                // Keep the NULL_VALUE placeholder for a tolerated missing output.
                continue;
            }
            if (isValidValue(output) == false) {
                return null;
            }
            extracted[i] = Objects.toString(output);
        }
        return extracted;
    }

    /**
     * Builds a row for the hit. When no values can be extracted the row has {@code null} values
     * and the training flag set; otherwise the train/test splitter decides the training flag.
     *
     * @param searchHit the hit to convert
     * @return the row for the hit
     */
    private Row createRow(SearchHit searchHit) {
        final String[] values = extractValues(searchHit);
        if (values == null) {
            return new Row(null, searchHit, true);
        }
        final boolean training = trainTestSplitter.get().isTraining(values);
        final Row row = new Row(values, searchHit, training);
        LOGGER.trace(
            () -> new ParameterizedMessage(
                "[{}] Extracted row: sort key = [{}], is_training = [{}], values = {}",
                context.jobId,
                row.getSortKey(),
                training,
                Arrays.toString(row.values)
            )
        );
        return row;
    }

    /**
     * Extracts all feature values of a hit: organic features first, then the outputs of each
     * processed field in order.
     *
     * @param searchHit the hit to read from
     * @return the values, or {@code null} as soon as any feature cannot be extracted
     */
    private String[] extractValues(SearchHit searchHit) {
        String[] row = new String[organicFeatures.length + processedFeatures.length];
        int next = 0;
        for (String featureName : organicFeatures) {
            String value = extractNonProcessedValues(searchHit, featureName);
            if (value == null) {
                return null;
            }
            row[next++] = value;
        }
        for (ProcessedField processedField : context.extractedFields.getProcessedFields()) {
            String[] processedValues = extractProcessedValue(processedField, searchHit);
            if (processedValues == null) {
                return null;
            }
            for (String processedValue : processedValues) {
                row[next++] = processedValue;
            }
        }
        return row;
    }

    /** Records that a search has failed, so the next consecutive failure is not retried. */
    private void markScrollAsErrored() {
        hasPreviousSearchFailed = true;
    }

    /** @return a new list of the organic feature names followed by the processed feature names */
    public List<String> getFieldNames() {
        List<String> names = new ArrayList<>(organicFeatures.length + processedFeatures.length);
        Collections.addAll(names, organicFeatures);
        Collections.addAll(names, processedFeatures);
        return names;
    }

    /** @return the extracted fields held by the context */
    public ExtractedFields getExtractedFields() {
        return context.extractedFields;
    }

    /**
     * Synchronously counts the matching documents and pairs the count with the number of features.
     *
     * @return the data summary (total hits as rows, organic plus processed features as columns)
     */
    public DataSummary collectDataSummary() {
        SearchResponse response = executeSearchRequest(buildDataSummarySearchRequestBuilder());
        final long rows = response.getHits().getTotalHits().value;
        LOGGER.debug(() -> new ParameterizedMessage("[{}] Data summary rows [{}]", context.jobId, rows));
        int cols = organicFeatures.length + processedFeatures.length;
        return new DataSummary(rows, cols);
    }

    /**
     * Asynchronous variant of {@link #collectDataSummary()}: counts the matching documents and
     * reports the count together with the number of features to the listener. Search failures
     * are forwarded to {@code actionListener.onFailure}.
     *
     * @param actionListener receives the data summary; must not be {@code null}
     */
    public void collectDataSummaryAsync(ActionListener<DataSummary> actionListener) {
        // Fail fast, before any request is assembled.
        Objects.requireNonNull(actionListener);

        SearchRequestBuilder searchRequestBuilder = buildDataSummarySearchRequestBuilder();
        final int cols = organicFeatures.length + processedFeatures.length;

        // Typing the listener explicitly makes the lambda parameter a SearchResponse; with a raw
        // CheckedConsumer it would be an Object and getHits() would not resolve.
        ActionListener<SearchResponse> searchListener = ActionListener.wrap(
            searchResponse -> actionListener.onResponse(new DataSummary(searchResponse.getHits().getTotalHits().value, cols)),
            actionListener::onFailure
        );

        ClientHelper.executeWithHeadersAsync(
            context.headers,
            "ml",
            client,
            SearchAction.INSTANCE,
            searchRequestBuilder.request(),
            searchListener
        );
    }

    /**
     * Builds a size-zero search that tracks total hits for the context query. When rows with
     * missing values are not supported, the query additionally requires every extracted field
     * to exist.
     *
     * @return the prepared search request builder
     */
    private SearchRequestBuilder buildDataSummarySearchRequestBuilder() {
        QueryBuilder summaryQuery;
        if (context.supportsRowsWithMissingValues) {
            summaryQuery = context.query;
        } else {
            summaryQuery = QueryBuilders.boolQuery().filter(context.query).filter(allExtractedFieldsExistQuery());
        }
        return new SearchRequestBuilder(client, SearchAction.INSTANCE)
            .setAllowPartialSearchResults(false)
            .setIndices(context.indices)
            .setSize(0)
            .setQuery(summaryQuery)
            .setTrackTotalHits(true)
            .setRuntimeMappings(context.runtimeMappings);
    }

    /** @return a bool query with one {@code exists} filter per extracted field name */
    private QueryBuilder allExtractedFieldsExistQuery() {
        BoolQueryBuilder allExist = QueryBuilders.boolQuery();
        for (ExtractedField field : context.extractedFields.getAllFields()) {
            allExist.filter(QueryBuilders.existsQuery(field.getName()));
        }
        return allExist;
    }

    /**
     * Delegates to {@code ExtractedFieldsDetector.getCategoricalOutputFields} with the context's
     * extracted fields.
     *
     * @param dataFrameAnalysis the analysis to determine categorical fields for
     * @return the set returned by the detector
     */
    public Set<String> getCategoricalFields(DataFrameAnalysis dataFrameAnalysis) {
        ExtractedFields fields = context.extractedFields;
        return ExtractedFieldsDetector.getCategoricalOutputFields(fields, dataFrameAnalysis);
    }

    /**
     * Tells whether a value has a supported type.
     *
     * @param obj the value to check; may be {@code null}
     * @return {@code true} for a {@link Number}, {@link String} or {@link Boolean};
     *         {@code false} for anything else, including {@code null}
     */
    public static boolean isValidValue(Object obj) {
        if (obj instanceof Number) {
            return true;
        }
        if (obj instanceof String) {
            return true;
        }
        return obj instanceof Boolean;
    }
}
