Skip to content

Add pipeline aggregations to NativeSearchQuery. #1809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 11, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -77,7 +77,6 @@
import org.elasticsearch.index.reindex.UpdateByQueryRequest;
import org.elasticsearch.index.reindex.UpdateByQueryRequestBuilder;
import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
@@ -1119,9 +1118,11 @@ private void prepareNativeSearch(NativeSearchQuery query, SearchSourceBuilder so
}

if (!isEmpty(query.getAggregations())) {
for (AbstractAggregationBuilder<?> aggregationBuilder : query.getAggregations()) {
sourceBuilder.aggregation(aggregationBuilder);
}
query.getAggregations().forEach(sourceBuilder::aggregation);
}

if (!isEmpty(query.getPipelineAggregations())) {
query.getPipelineAggregations().forEach(sourceBuilder::aggregation);
}

}
@@ -1144,9 +1145,11 @@ private void prepareNativeSearch(SearchRequestBuilder searchRequestBuilder, Nati
}

if (!isEmpty(nativeSearchQuery.getAggregations())) {
for (AbstractAggregationBuilder<?> aggregationBuilder : nativeSearchQuery.getAggregations()) {
searchRequestBuilder.addAggregation(aggregationBuilder);
}
nativeSearchQuery.getAggregations().forEach(searchRequestBuilder::addAggregation);
}

if (!isEmpty(nativeSearchQuery.getPipelineAggregations())) {
nativeSearchQuery.getPipelineAggregations().forEach(searchRequestBuilder::addAggregation);
}
}

Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.script.mustache.SearchTemplateRequestBuilder;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.collapse.CollapseBuilder;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.sort.SortBuilder;
@@ -48,6 +49,7 @@ public class NativeSearchQuery extends AbstractQuery {
private final List<ScriptField> scriptFields = new ArrayList<>();
@Nullable private CollapseBuilder collapseBuilder;
@Nullable private List<AbstractAggregationBuilder<?>> aggregations;
@Nullable private List<PipelineAggregationBuilder> pipelineAggregations;
@Nullable private HighlightBuilder highlightBuilder;
@Nullable private HighlightBuilder.Field[] highlightFields;
@Nullable private List<IndexBoost> indicesBoost;
@@ -143,6 +145,11 @@ public List<AbstractAggregationBuilder<?>> getAggregations() {
return aggregations;
}

@Nullable
public List<PipelineAggregationBuilder> getPipelineAggregations() {
return pipelineAggregations;
}

public void addAggregation(AbstractAggregationBuilder<?> aggregationBuilder) {

if (aggregations == null) {
@@ -156,6 +163,10 @@ public void setAggregations(List<AbstractAggregationBuilder<?>> aggregations) {
this.aggregations = aggregations;
}

public void setPipelineAggregations(List<PipelineAggregationBuilder> pipelineAggregationBuilders) {
this.pipelineAggregations = pipelineAggregationBuilders;
}

@Nullable
public List<IndexBoost> getIndicesBoost() {
return indicesBoost;
Original file line number Diff line number Diff line change
@@ -27,6 +27,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.script.mustache.SearchTemplateRequestBuilder;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.collapse.CollapseBuilder;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.sort.SortBuilder;
@@ -55,6 +56,7 @@ public class NativeSearchQueryBuilder {
private final List<ScriptField> scriptFields = new ArrayList<>();
private final List<SortBuilder<?>> sortBuilders = new ArrayList<>();
private final List<AbstractAggregationBuilder<?>> aggregationBuilders = new ArrayList<>();
private final List<PipelineAggregationBuilder> pipelineAggregationBuilders = new ArrayList<>();
@Nullable private HighlightBuilder highlightBuilder;
@Nullable private HighlightBuilder.Field[] highlightFields;
private Pageable pageable = Pageable.unpaged();
@@ -105,6 +107,14 @@ public NativeSearchQueryBuilder addAggregation(AbstractAggregationBuilder<?> agg
return this;
}

/**
* @since 4.3
*/
public NativeSearchQueryBuilder addAggregation(PipelineAggregationBuilder pipelineAggregationBuilder) {
this.pipelineAggregationBuilders.add(pipelineAggregationBuilder);
return this;
}

public NativeSearchQueryBuilder withHighlightBuilder(HighlightBuilder highlightBuilder) {
this.highlightBuilder = highlightBuilder;
return this;
@@ -239,6 +249,10 @@ public NativeSearchQuery build() {
nativeSearchQuery.setAggregations(aggregationBuilders);
}

if (!isEmpty(pipelineAggregationBuilders)) {
nativeSearchQuery.setPipelineAggregations(pipelineAggregationBuilders);
}

if (minScore > 0) {
nativeSearchQuery.setMinScore(minScore);
}
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@
import static org.assertj.core.api.Assertions.*;
import static org.elasticsearch.index.query.QueryBuilders.*;
import static org.elasticsearch.search.aggregations.AggregationBuilders.*;
import static org.elasticsearch.search.aggregations.PipelineAggregatorBuilders.*;
import static org.springframework.data.elasticsearch.annotations.FieldType.*;
import static org.springframework.data.elasticsearch.annotations.FieldType.Integer;

@@ -26,9 +27,14 @@
import java.util.List;

import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.pipeline.InternalStatsBucket;
import org.elasticsearch.search.aggregations.pipeline.ParsedStatsBucket;
import org.elasticsearch.search.aggregations.pipeline.StatsBucket;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
@@ -109,7 +115,7 @@ public void after() {
indexOperations.delete();
}

@Test
@Test // DATAES-96
public void shouldReturnAggregatedResponseForGivenSearchQuery() {

// given
@@ -130,6 +136,56 @@ public void shouldReturnAggregatedResponseForGivenSearchQuery() {
assertThat(searchHits.hasSearchHits()).isFalse();
}

@Test // #1255
@DisplayName("should work with pipeline aggregations")
void shouldWorkWithPipelineAggregations() {

IndexInitializer.init(operations.indexOps(PipelineAggsEntity.class));
operations.save( //
new PipelineAggsEntity("1-1", "one"), //
new PipelineAggsEntity("2-1", "two"), //
new PipelineAggsEntity("2-2", "two"), //
new PipelineAggsEntity("3-1", "three"), //
new PipelineAggsEntity("3-2", "three"), //
new PipelineAggsEntity("3-3", "three") //
); //

NativeSearchQuery searchQuery = new NativeSearchQueryBuilder() //
.withQuery(matchAllQuery()) //
.withSearchType(SearchType.DEFAULT) //
.addAggregation(terms("keyword_aggs").field("keyword")) //
.addAggregation(statsBucket("keyword_bucket_stats", "keyword_aggs._count")) //
.withMaxResults(0) //
.build();

SearchHits<PipelineAggsEntity> searchHits = operations.search(searchQuery, PipelineAggsEntity.class);

Aggregations aggregations = searchHits.getAggregations();
assertThat(aggregations).isNotNull();
assertThat(aggregations.asMap().get("keyword_aggs")).isNotNull();
Aggregation keyword_bucket_stats = aggregations.asMap().get("keyword_bucket_stats");
assertThat(keyword_bucket_stats).isInstanceOf(StatsBucket.class);
if (keyword_bucket_stats instanceof ParsedStatsBucket) {
// Rest client
ParsedStatsBucket statsBucket = (ParsedStatsBucket) keyword_bucket_stats;
assertThat(statsBucket.getMin()).isEqualTo(1.0);
assertThat(statsBucket.getMax()).isEqualTo(3.0);
assertThat(statsBucket.getAvg()).isEqualTo(2.0);
assertThat(statsBucket.getSum()).isEqualTo(6.0);
assertThat(statsBucket.getCount()).isEqualTo(3L);
}
if (keyword_bucket_stats instanceof InternalStatsBucket) {
// transport client
InternalStatsBucket statsBucket = (InternalStatsBucket) keyword_bucket_stats;
assertThat(statsBucket.getMin()).isEqualTo(1.0);
assertThat(statsBucket.getMax()).isEqualTo(3.0);
assertThat(statsBucket.getAvg()).isEqualTo(2.0);
assertThat(statsBucket.getSum()).isEqualTo(6.0);
assertThat(statsBucket.getCount()).isEqualTo(3L);
}
}

// region entities
@Document(indexName = "test-index-articles-core-aggregation")
static class ArticleEntity {

@@ -256,4 +312,34 @@ public IndexQuery buildIndex() {
}
}

@Document(indexName = "pipeline-aggs")
static class PipelineAggsEntity {
@Id private String id;
@Field(type = Keyword) private String keyword;

public PipelineAggsEntity() {}

public PipelineAggsEntity(String id, String keyword) {
this.id = id;
this.keyword = keyword;
}

public String getId() {
return id;
}

public void setId(String id) {
this.id = id;
}

public String getKeyword() {
return keyword;
}

public void setKeyword(String keyword) {
this.keyword = keyword;
}
}
// endregion

}