This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 6225af2

[CROSSDATA-832] Make Elasticsearch Crossdata connector comply with the flattening algorithm requirements (#810)
* Adapted type tests to core template
* Improved template flexibility
* Added array type to Elasticsearch connector
1 parent b45d568 commit 6225af2

6 files changed

Lines changed: 275 additions & 191 deletions


core/src/test/scala/org/apache/spark/sql/crossdata/test/SharedXDContextTypesTest.scala

Lines changed: 40 additions & 14 deletions
@@ -52,21 +52,12 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest {
    * the types test table.
    */
 
+  val arrayFlattenTestColumn: String = "arraystructarraystruct" /* Column used to test flattening of arrays */
 
   //Template: This is the template implementation and shouldn't be modified in any specific test
 
-  def doTypesTest(datasourceName: String): Unit = {
-    for(executionType <- ExecutionType.Spark::ExecutionType.Native::Nil)
-      datasourceName should s"provide the right types for $executionType execution" in {
-        assumeEnvironmentIsUpAndRunning
-        val dframe = sql("SELECT " + typesSet.map(_.colname).mkString(", ") + s" FROM $dataTypesTableName")
-        for(
-          (tpe, i) <- typesSet zipWithIndex;
-          typeCheck <- tpe.typeCheck
-        ) typeCheck(dframe.collect(executionType).head(i))
-      }
-
-  //Multi-level column flat test
+  protected def multilevelFlattenTests(datasourceName: String): Unit = {
+    //Multi-level column flatten test
 
     it should "provide flattened column names through the `annotatedCollect` method" in {
       val dataFrame = sql("SELECT structofstruct.struct1.structField1 FROM typesCheckTable")
@@ -81,8 +72,24 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest {
       rows.length shouldBe 1
     }
 
+    it should "be able to flatten whole rows" in {
+      val dataFrame = sql("SELECT * FROM typesCheckTable")
+      val rows = dataFrame.flattenedCollect()
+      val hasComposedTypes = rows.head.schema.fields exists { field =>
+        field.dataType match {
+          case _: StructType | _: ArrayType => true
+          case _ => false
+        }
+      }
+      hasComposedTypes shouldBe false
+    }
+  }
+
+  protected def arrayFlattenTests(datasourceName: String): Unit = {
+    //Multi-level column, with nested arrays, flatten test
+
     it should "be able to vertically flatten results for array columns" in {
-      val dataFrame = sql(s"SELECT arraystructarraystruct FROM typesCheckTable")
+      val dataFrame = sql(s"SELECT $arrayFlattenTestColumn FROM typesCheckTable")
       val res = dataFrame.flattenedCollect()
 
       // No array columns should be found in the result schema
@@ -100,7 +107,7 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest {
     }
 
     it should "correctly apply user limits to a vertically flattened array column" in {
-      val dataFrame = sql(s"SELECT arraystructarraystruct FROM typesCheckTable LIMIT 1")
+      val dataFrame = sql(s"SELECT $arrayFlattenTestColumn FROM typesCheckTable LIMIT 1")
       val res = dataFrame.flattenedCollect()
       res.length shouldBe 1
     }
@@ -113,6 +120,24 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest {
 
   }
 
+  def doTypesTest(datasourceName: String): Unit = {
+
+    for(executionType <- ExecutionType.Spark::ExecutionType.Native::Nil)
+      datasourceName should s"provide the right types for $executionType execution" in {
+        assumeEnvironmentIsUpAndRunning
+        val dframe = sql("SELECT " + typesSet.map(_.colname).mkString(", ") + s" FROM $dataTypesTableName")
+        for(
+          (tpe, i) <- typesSet zipWithIndex;
+          typeCheck <- tpe.typeCheck
+        ) typeCheck(dframe.collect(executionType).head(i))
+      }
+
+    multilevelFlattenTests(datasourceName)
+
+    arrayFlattenTests(datasourceName)
+
+  }
+
   abstract override def saveTestData: Unit = {
     super.saveTestData
     require(saveTypesData > 0, emptyTypesSetError)
@@ -167,6 +192,7 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest {
 }
 
 object SharedXDContextTypesTest {
+
   val dataTypesTableName = "typesCheckTable"
   case class SparkSQLColDef(colname: String, sqlType: String, typeCheck: Option[Any => Unit] = None)
   object SparkSQLColDef {
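For context, the flattening contract these template tests exercise is that `flattenedCollect` exposes nested struct fields as dot-separated top-level columns and leaves no StructType or ArrayType column in the result schema. The sketch below is not Crossdata's actual implementation, only an illustration of the expected naming, assuming just a spark-sql dependency:

    import org.apache.spark.sql.types._

    // Illustration only: derive the dot-separated column names the flattening tests expect.
    object FlattenNamesSketch extends App {
      def flattenNames(schema: StructType, prefix: String = ""): Seq[String] =
        schema.fields.toSeq.flatMap {
          case StructField(name, st: StructType, _, _) => flattenNames(st, s"$prefix$name.")
          case StructField(name, _, _, _)              => Seq(s"$prefix$name")
        }

      val schema = StructType(Seq(
        StructField("structofstruct", StructType(Seq(
          StructField("field1", IntegerType),
          StructField("struct1", StructType(Seq(StructField("structField1", IntegerType))))
        )))
      ))

      // Prints: structofstruct.field1, structofstruct.struct1.structField1
      println(flattenNames(schema).mkString(", "))
    }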

elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchQueryProcessor.scala

Lines changed: 3 additions & 4 deletions
@@ -23,9 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Limit, LogicalPlan}
 import org.apache.spark.sql.{Row, sources}
-import org.apache.spark.sql.sources.CatalystToCrossdataAdapter.{BaseLogicalPlan, FilterReport, ProjectReport, SimpleLogicalPlan, CrossdataExecutionPlan}
+import org.apache.spark.sql.sources.CatalystToCrossdataAdapter.{BaseLogicalPlan, CrossdataExecutionPlan, FilterReport, ProjectReport, SimpleLogicalPlan}
 import org.apache.spark.sql.sources.{CatalystToCrossdataAdapter, Filter => SourceFilter}
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{StructField, StructType, ArrayType}
 import org.elasticsearch.action.search.SearchResponse
 
 import scala.util.{Failure, Try}
@@ -101,8 +101,6 @@ class ElasticSearchQueryProcessor(val logicalPlan: LogicalPlan, val parameters:
       case sources.StringStartsWith(attribute, value) => prefixQuery(attribute, value.toLowerCase)
     }
 
-    import scala.collection.JavaConversions._
-
     val searchFilters = sFilters.collect {
       case sources.EqualTo(attribute, value) => termQuery(attribute, value)
      case sources.GreaterThan(attribute, value) => rangeQuery(attribute).from(value).includeLower(false)
@@ -131,6 +129,7 @@ class ElasticSearchQueryProcessor(val logicalPlan: LogicalPlan, val parameters:
     val subDocuments = schemaProvided.toSeq flatMap {
       _.fields collect {
         case StructField(name, _: StructType, _, _) => name
+        case StructField(name, ArrayType(_: StructType, _), _, _) => name
       }
     }
     val stringFields: Seq[String] = fields.view map (_.name) filterNot (subDocuments contains _)
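The added pattern means array-of-struct columns are now classified as sub-documents alongside plain struct columns, so they are read back as documents rather than as simple string fields. A standalone sketch of that classification, illustration only, with a made-up sample schema and just a spark-sql dependency:

    import org.apache.spark.sql.types._

    // Illustration only: which fields the subDocuments collect above would now pick up.
    object SubDocumentSketch extends App {
      val schema = StructType(Seq(
        StructField("name", StringType),
        StructField("subdocument", StructType(Seq(StructField("field1", IntegerType)))),
        StructField("arraystruct", ArrayType(StructType(Seq(StructField("field1", LongType)))))
      ))

      val subDocuments = schema.fields.collect {
        case StructField(name, _: StructType, _, _) => name
        case StructField(name, ArrayType(_: StructType, _), _, _) => name // the newly added case
      }

      // Prints: subdocument, arraystruct   ("name" stays a plain field)
      println(subDocuments.mkString(", "))
    }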

elasticsearch/src/main/scala/com/stratio/crossdata/connector/elasticsearch/ElasticSearchRowConverter.scala

Lines changed: 9 additions & 1 deletion
@@ -58,7 +58,9 @@ object ElasticSearchRowConverter {
     // TODO: Note that if a nested subdocument is targeted, it won't work and this algorithm should be made recursive.
     (hitFields.get(name) orElse subDocuments.get(name)).flatMap(Option(_)) map {
       ((value: Any) => enforceCorrectType(value, schemaMap(name))) compose {
-        case hitField: SearchHitField => hitField.getValue
+        case hitField: SearchHitField =>
+          if(hitField.getValues.size()>1) hitField.getValues
+          else hitField.getValue
         case other => other
       }
     } orNull
@@ -85,6 +87,7 @@ object ElasticSearchRowConverter {
       case DateType => toDate(value)
       case BinaryType => toBinary(value)
       case schema: StructType => toRow(value, schema)
+      case ArrayType(elementType: DataType, _) => toArray(value, elementType)
       case _ =>
         sys.error(s"Unsupported datatype conversion [${value.getClass}},$desiredType]")
         value
@@ -174,4 +177,9 @@ object ElasticSearchRowConverter {
     case _ => sys.error(s"Unsupported datatype conversion [${value.getClass}},Row")
   }
 
+  def toArray(value: Any, elementType: DataType): Seq[Any] = value match {
+    case arr: util.ArrayList[Any] =>
+      arr.toArray.map(enforceCorrectType(_, elementType))
+  }
+
 }
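The new `toArray` helper takes the Java list that a multi-valued Elasticsearch hit field returns and coerces each element before Spark sees the value as an ARRAY column. A standalone sketch of that per-element conversion, illustration only, with a stand-in for `enforceCorrectType` and using only the JDK and Scala collections:

    import java.util
    import scala.collection.JavaConverters._

    // Illustration only: convert a Java list from a search hit into a Scala Seq, element by element.
    object ToArraySketch extends App {
      val raw = new util.ArrayList[Any]()
      raw.add(1); raw.add(2); raw.add(3)

      // Stand-in for enforceCorrectType(_, IntegerType); the real method dispatches on the Spark DataType.
      val coerce: Any => Any = { case n: Number => n.intValue() }

      val asSeq: Seq[Any] = raw.asScala.map(coerce).toList

      println(asSeq) // List(1, 2, 3)
    }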
Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
+/*
+ * Copyright (C) 2015 Stratio (http://stratio.com)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *         http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.stratio.crossdata.connector.elasticsearch
+
+import java.util.{GregorianCalendar, UUID}
+
+import com.sksamuel.elastic4s.ElasticDsl._
+import com.sksamuel.elastic4s.mappings.FieldType._
+import com.sksamuel.elastic4s.mappings.{MappingDefinition, TypedFieldDefinition}
+import com.stratio.common.utils.components.logger.impl.SparkLoggerComponent
+import com.typesafe.config.ConfigFactory
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.crossdata.test.SharedXDContextTypesTest
+import org.apache.spark.sql.crossdata.test.SharedXDContextTypesTest.SparkSQLColDef
+import org.joda.time.DateTime
+
+trait ElasticDataTypes extends ElasticWithSharedContext
+  with SharedXDContextTypesTest
+  with ElasticSearchDataTypesDefaultConstants
+  with SparkLoggerComponent {
+
+  override val dataTypesSparkOptions = Map(
+    "resource" -> s"$Index/$Type",
+    "es.nodes" -> s"$ElasticHost",
+    "es.port" -> s"$ElasticRestPort",
+    "es.nativePort" -> s"$ElasticNativePort",
+    "es.cluster" -> s"$ElasticClusterName",
+    "es.nodes.wan.only" -> "true",
+    "es.read.field.as.array.include" -> Seq(
+      "arrayint"
+    ).mkString(",")
+  )
+
+  protected case class ESColumnData(elasticType: Option[TypedFieldDefinition], data: () => Any)
+  protected object ESColumnData {
+    def apply(data: () => Any): ESColumnData = ESColumnData(None, data)
+    def apply(elasticType: TypedFieldDefinition, data: () => Any): ESColumnData = ESColumnData(Some(elasticType), data)
+  }
+
+
+  override val arrayFlattenTestColumn: String = "arraystruct"
+
+  protected val dataTest: Seq[(SparkSQLColDef, ESColumnData)] = Seq(
+    (SparkSQLColDef("id", "INT", _ shouldBe a[java.lang.Integer]), ESColumnData("id" typed IntegerType, () => 1)),
+    (SparkSQLColDef("age", "LONG", _ shouldBe a[java.lang.Long]), ESColumnData("age" typed LongType, () => 1)),
+    (
+      SparkSQLColDef("description", "STRING", _ shouldBe a[java.lang.String]),
+      ESColumnData("description" typed StringType, () => "1")
+    ),
+    (
+      SparkSQLColDef("name", "STRING", _ shouldBe a[java.lang.String]),
+      ESColumnData( "name" typed StringType index NotAnalyzed, () => "1")
+    ),
+    (
+      SparkSQLColDef("enrolled", "BOOLEAN", _ shouldBe a[java.lang.Boolean]),
+      ESColumnData("enrolled" typed BooleanType, () => false)
+    ),
+    (
+      SparkSQLColDef("birthday", "DATE", _ shouldBe a [java.sql.Date]),
+      ESColumnData("birthday" typed DateType, () => DateTime.parse(1980 + "-01-01T10:00:00-00:00").toDate)
+    ),
+    (
+      SparkSQLColDef("salary", "DOUBLE", _ shouldBe a[java.lang.Double]),
+      ESColumnData("salary" typed DoubleType, () => 0.15)
+    ),
+    (
+      SparkSQLColDef("timecol", "TIMESTAMP", _ shouldBe a[java.sql.Timestamp]),
+      ESColumnData(
+        "timecol" typed DateType,
+        () => new java.sql.Timestamp(new GregorianCalendar(1970, 0, 1, 0, 0, 0).getTimeInMillis)
+      )
+    ),
+    (
+      SparkSQLColDef("float", "FLOAT", _ shouldBe a[java.lang.Float]),
+      ESColumnData("float" typed FloatType, () => 0.15)
+    ),
+    (
+      SparkSQLColDef("binary", "BINARY", x => x.isInstanceOf[Array[Byte]] shouldBe true),
+      ESColumnData("binary" typed BinaryType, () => "YWE=".getBytes)
+    ),
+    (
+      SparkSQLColDef("tinyint", "TINYINT", _ shouldBe a[java.lang.Byte]),
+      ESColumnData("tinyint" typed ByteType, () => Byte.MinValue)
+    ),
+    (
+      SparkSQLColDef("smallint", "SMALLINT", _ shouldBe a[java.lang.Short]),
+      ESColumnData("smallint" typed ShortType, () => Short.MaxValue)
+    ),
+    (
+      SparkSQLColDef("subdocument", "STRUCT<field1: INT>", _ shouldBe a [Row]),
+      ESColumnData("subdocument" inner ("field1" typed IntegerType), () => Map( "field1" -> 15))
+    ),
+    (
+      SparkSQLColDef(
+        "structofstruct",
+        "STRUCT<field1: INT, struct1: STRUCT<structField1: INT>>",
+        { res =>
+          res shouldBe a[GenericRowWithSchema]
+          res.asInstanceOf[GenericRowWithSchema].get(1) shouldBe a[GenericRowWithSchema]
+        }
+      ),
+      ESColumnData(
+        "structofstruct" inner ("field1" typed IntegerType, "struct1" inner("structField1" typed IntegerType)),
+        () => Map("field1" -> 15, "struct1" -> Map("structField1" -> 42))
+      )
+    ),
+    (
+      SparkSQLColDef("arrayint", "ARRAY<INT>", _ shouldBe a[Seq[_]]),
+      ESColumnData(() => Seq(1,2,3,4))
+    ),
+    (
+      SparkSQLColDef("arraystruct", "ARRAY<STRUCT<field1: LONG, field2: LONG>>", _ shouldBe a[Seq[_]]),
+      ESColumnData(
+        "arraystruct" nested(
+          "field1" typed LongType,
+          "field2" typed LongType
+        ),
+        () =>
+          Array(
+            Map(
+              "field1" -> 11,
+              "field2" -> 12
+            ),
+            Map(
+              "field1" -> 21,
+              "field2" -> 22
+            ),
+            Map(
+              "field1" -> 31,
+              "field2" -> 32
+            )
+          )
+      )
+    )/*,
+    (
+      SparkSQLColDef(
+        "arraystructarraystruct",
+        "ARRAY<STRUCT<stringfield: STRING, arrayfield: ARRAY<STRUCT<field1: INT, field2: INT>>>>",
+        { res =>
+          res shouldBe a[Seq[_]]
+          res.asInstanceOf[Seq[_]].head shouldBe a[Row]
+          res.asInstanceOf[Seq[_]].head.asInstanceOf[Row].get(1) shouldBe a[Seq[_]]
+          res.asInstanceOf[Seq[_]].head.asInstanceOf[Row].get(1).asInstanceOf[Seq[_]].head shouldBe a[Row]
+        }
+      ),
+      ESColumnData(
+        "arraystructarraystruct" nested (
+          "stringfield" typed StringType,
+          "arrayfield" nested (
+            "field1" typed IntegerType,
+            "field2" typed IntegerType
+          )
+        ),
+        () => Array(
+          Map(
+            "stringfield" -> "hello",
+            "arrayfield" -> Array(
+              Map(
+                "field1" -> 10,
+                "field2" -> 20
+              )
+            )
+          )
+        )
+      )
+    )*/
+  )
+
+
+  override protected def typesSet: Seq[SparkSQLColDef] = dataTest.map(_._1)
+
+
+  abstract override def saveTestData: Unit = {
+    require(saveTypesData > 0, emptyTypesSetError)
+  }
+
+  override def saveTypesData: Int = {
+    client.get.execute {
+      val fieldsData = dataTest map {
+        case (SparkSQLColDef(fieldName, _, _), ESColumnData(_, data)) => (fieldName, data())
+      }
+      index into Index / Type fields (fieldsData: _*)
+    }.await
+    client.get.execute {
+      flush index Index
+    }.await
+    1
+  }
+
+  override def typeMapping(): MappingDefinition =
+    Type fields (
+      dataTest collect {
+        case (_, ESColumnData(Some(mapping), _)) => mapping
+      }: _*
+    )
+
+  override val emptyTypesSetError: String = "Couldn't insert Elasticsearch types test data"
+
+}
+
+
+trait ElasticSearchDataTypesDefaultConstants extends ElasticSearchDefaultConstants{
+  private lazy val config = ConfigFactory.load()
+  override val Index = s"idxname${UUID.randomUUID.toString.replaceAll("-", "")}"
+  override val Type = s"typename${UUID.randomUUID.toString.replaceAll("-", "")}"
+
+}
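With this trait in place, a concrete Elasticsearch test spec would roughly only need to mix it in and invoke the shared template entry point, as sketched below. The class name and the datasource label are placeholders rather than the actual spec in the repository, and the spec inherits its ScalaTest and shared-context plumbing from the traits it mixes in.

    // Hypothetical wiring (names are placeholders): ElasticDataTypes supplies the ES mappings,
    // test data and Spark options; doTypesTest registers the type checks plus both groups of
    // flattening tests, using the overridden arrayFlattenTestColumn ("arraystruct").
    class ElasticSearchTypesExampleSpec extends ElasticDataTypes {
      doTypesTest("ElasticSearch")
    }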
