|
41 | 41 | import org.apache.pinot.common.utils.RegexpPatternConverterUtils; |
42 | 42 | import org.apache.pinot.common.utils.request.RequestUtils; |
43 | 43 | import org.apache.pinot.segment.spi.AggregationFunctionType; |
| 44 | +import org.apache.pinot.spi.data.FieldSpec.DataType; |
44 | 45 | import org.apache.pinot.spi.exception.BadQueryRequestException; |
45 | 46 | import org.apache.pinot.sql.FilterKind; |
46 | 47 | import org.apache.pinot.sql.parsers.CalciteSqlParser; |
47 | 48 |
|
48 | 49 |
|
49 | 50 | public class RequestContextUtils { |
| 51 | + private static final String UNSUPPORTED_RHS_MESSAGE = |
| 52 | + "Pinot does not support column or function on the right-hand side of the predicate"; |
| 53 | + |
50 | 54 | private RequestContextUtils() { |
51 | 55 | } |
52 | 56 |
|
@@ -217,23 +221,23 @@ private static FilterContext getFilterInner(Function thriftFunction) { |
217 | 221 | case GREATER_THAN: |
218 | 222 | return FilterContext.forPredicate( |
219 | 223 | new RangePredicate(getExpression(operands.get(0)), false, getStringValue(operands.get(1)), false, |
220 | | - RangePredicate.UNBOUNDED, new LiteralContext(operands.get(1).getLiteral()).getType())); |
| 224 | + RangePredicate.UNBOUNDED, getLiteralType(operands.get(1)))); |
221 | 225 | case GREATER_THAN_OR_EQUAL: |
222 | 226 | return FilterContext.forPredicate( |
223 | 227 | new RangePredicate(getExpression(operands.get(0)), true, getStringValue(operands.get(1)), false, |
224 | | - RangePredicate.UNBOUNDED, new LiteralContext(operands.get(1).getLiteral()).getType())); |
| 228 | + RangePredicate.UNBOUNDED, getLiteralType(operands.get(1)))); |
225 | 229 | case LESS_THAN: |
226 | 230 | return FilterContext.forPredicate( |
227 | 231 | new RangePredicate(getExpression(operands.get(0)), false, RangePredicate.UNBOUNDED, false, |
228 | | - getStringValue(operands.get(1)), new LiteralContext(operands.get(1).getLiteral()).getType())); |
| 232 | + getStringValue(operands.get(1)), getLiteralType(operands.get(1)))); |
229 | 233 | case LESS_THAN_OR_EQUAL: |
230 | 234 | return FilterContext.forPredicate( |
231 | 235 | new RangePredicate(getExpression(operands.get(0)), false, RangePredicate.UNBOUNDED, true, |
232 | | - getStringValue(operands.get(1)), new LiteralContext(operands.get(1).getLiteral()).getType())); |
| 236 | + getStringValue(operands.get(1)), getLiteralType(operands.get(1)))); |
233 | 237 | case BETWEEN: |
234 | 238 | return FilterContext.forPredicate( |
235 | 239 | new RangePredicate(getExpression(operands.get(0)), true, getStringValue(operands.get(1)), true, |
236 | | - getStringValue(operands.get(2)), new LiteralContext(operands.get(1).getLiteral()).getType())); |
| 240 | + getStringValue(operands.get(2)), getLiteralType(operands.get(1)))); |
237 | 241 | case RANGE: |
238 | 242 | return FilterContext.forPredicate( |
239 | 243 | new RangePredicate(getExpression(operands.get(0)), getStringValue(operands.get(1)))); |
@@ -278,12 +282,7 @@ private static FilterContext getFilterInner(Function thriftFunction) { |
278 | 282 | } |
279 | 283 |
|
280 | 284 | public static String getStringValue(Expression thriftExpression) { |
281 | | - Literal literal = thriftExpression.getLiteral(); |
282 | | - if (literal == null) { |
283 | | - throw new BadQueryRequestException( |
284 | | - "Pinot does not support column or function on the right-hand side of the predicate"); |
285 | | - } |
286 | | - return RequestUtils.getLiteralString(literal); |
| 285 | + return evaluateLiteralValue(thriftExpression).getStringValue(); |
287 | 286 | } |
288 | 287 |
|
289 | 288 | /** |
@@ -402,23 +401,23 @@ private static FilterContext getFilterInner(FunctionContext filterFunction) { |
402 | 401 | case GREATER_THAN: |
403 | 402 | return FilterContext.forPredicate( |
404 | 403 | new RangePredicate(operands.get(0), false, getStringValue(operands.get(1)), false, RangePredicate.UNBOUNDED, |
405 | | - operands.get(1).getLiteral().getType())); |
| 404 | + getLiteralType(operands.get(1)))); |
406 | 405 | case GREATER_THAN_OR_EQUAL: |
407 | 406 | return FilterContext.forPredicate( |
408 | 407 | new RangePredicate(operands.get(0), true, getStringValue(operands.get(1)), false, RangePredicate.UNBOUNDED, |
409 | | - operands.get(1).getLiteral().getType())); |
| 408 | + getLiteralType(operands.get(1)))); |
410 | 409 | case LESS_THAN: |
411 | 410 | return FilterContext.forPredicate( |
412 | 411 | new RangePredicate(operands.get(0), false, RangePredicate.UNBOUNDED, false, getStringValue(operands.get(1)), |
413 | | - operands.get(1).getLiteral().getType())); |
| 412 | + getLiteralType(operands.get(1)))); |
414 | 413 | case LESS_THAN_OR_EQUAL: |
415 | 414 | return FilterContext.forPredicate( |
416 | 415 | new RangePredicate(operands.get(0), false, RangePredicate.UNBOUNDED, true, getStringValue(operands.get(1)), |
417 | | - operands.get(1).getLiteral().getType())); |
| 416 | + getLiteralType(operands.get(1)))); |
418 | 417 | case BETWEEN: |
419 | 418 | return FilterContext.forPredicate( |
420 | 419 | new RangePredicate(operands.get(0), true, getStringValue(operands.get(1)), true, |
421 | | - getStringValue(operands.get(2)), operands.get(1).getLiteral().getType())); |
| 420 | + getStringValue(operands.get(2)), getLiteralType(operands.get(1)))); |
422 | 421 | case RANGE: |
423 | 422 | return FilterContext.forPredicate(new RangePredicate(operands.get(0), getStringValue(operands.get(1)))); |
424 | 423 | case REGEXP_LIKE: |
@@ -462,11 +461,96 @@ private static FilterContext getFilterInner(FunctionContext filterFunction) { |
462 | 461 | // literal context doesn't support float, and we cannot differentiate explicit string literal and literal |
463 | 462 | // without explicit type, so we always convert the literal into string. |
464 | 463 | private static String getStringValue(ExpressionContext expressionContext) { |
465 | | - if (expressionContext.getType() != ExpressionContext.Type.LITERAL) { |
466 | | - throw new BadQueryRequestException( |
467 | | - "Pinot does not support column or function on the right-hand side of the predicate"); |
| 464 | + return evaluateLiteralValue(expressionContext).getStringValue(); |
| 465 | + } |
| 466 | + |
| 467 | + private static DataType getLiteralType(Expression thriftExpression) { |
| 468 | + return evaluateLiteralValue(thriftExpression).getType(); |
| 469 | + } |
| 470 | + |
| 471 | + private static DataType getLiteralType(ExpressionContext expressionContext) { |
| 472 | + return evaluateLiteralValue(expressionContext).getType(); |
| 473 | + } |
| 474 | + |
| 475 | + private static EvaluatedLiteralValue evaluateLiteralValue(Expression thriftExpression) { |
| 476 | + Literal literal = thriftExpression.getLiteral(); |
| 477 | + if (literal != null) { |
| 478 | + return fromLiteralContext(new LiteralContext(literal)); |
468 | 479 | } |
469 | | - return expressionContext.getLiteral().getStringValue(); |
| 480 | + Function function = thriftExpression.getFunctionCall(); |
| 481 | + if (function != null) { |
| 482 | + return evaluateFunctionLiteral(function.getOperator(), function.getOperands(), |
| 483 | + RequestContextUtils::evaluateLiteralValue); |
| 484 | + } |
| 485 | + throw new BadQueryRequestException(UNSUPPORTED_RHS_MESSAGE); |
| 486 | + } |
| 487 | + |
| 488 | + private static EvaluatedLiteralValue evaluateLiteralValue(ExpressionContext expressionContext) { |
| 489 | + if (expressionContext.getType() == ExpressionContext.Type.LITERAL) { |
| 490 | + return fromLiteralContext(expressionContext.getLiteral()); |
| 491 | + } |
| 492 | + if (expressionContext.getType() == ExpressionContext.Type.FUNCTION) { |
| 493 | + FunctionContext function = expressionContext.getFunction(); |
| 494 | + return evaluateFunctionLiteral(function.getFunctionName(), function.getArguments(), |
| 495 | + RequestContextUtils::evaluateLiteralValue); |
| 496 | + } |
| 497 | + throw new BadQueryRequestException(UNSUPPORTED_RHS_MESSAGE); |
| 498 | + } |
| 499 | + |
| 500 | + private static <T> EvaluatedLiteralValue evaluateFunctionLiteral(String functionName, List<T> operands, |
| 501 | + java.util.function.Function<T, EvaluatedLiteralValue> evaluator) { |
| 502 | + if (!functionName.equalsIgnoreCase("cast")) { |
| 503 | + throw new BadQueryRequestException(UNSUPPORTED_RHS_MESSAGE); |
| 504 | + } |
| 505 | + Preconditions.checkState(operands.size() == 2, "CAST function must have exactly 2 operands"); |
| 506 | + EvaluatedLiteralValue source = evaluator.apply(operands.get(0)); |
| 507 | + EvaluatedLiteralValue target = evaluator.apply(operands.get(1)); |
| 508 | + DataType targetType = getCastTargetType(target.getStringValue()); |
| 509 | + if (source.getType() == DataType.UNKNOWN) { |
| 510 | + return new EvaluatedLiteralValue(source.getStringValue(), targetType); |
| 511 | + } |
| 512 | + Object converted = targetType.convert(source.getStringValue()); |
| 513 | + return new EvaluatedLiteralValue(targetType.toString(converted), targetType); |
| 514 | + } |
| 515 | + |
| 516 | + private static DataType getCastTargetType(String targetTypeString) { |
| 517 | + switch (targetTypeString.toUpperCase()) { |
| 518 | + case "INT": |
| 519 | + case "INTEGER": |
| 520 | + return DataType.INT; |
| 521 | + case "LONG": |
| 522 | + case "BIGINT": |
| 523 | + return DataType.LONG; |
| 524 | + case "FLOAT": |
| 525 | + return DataType.FLOAT; |
| 526 | + case "DOUBLE": |
| 527 | + return DataType.DOUBLE; |
| 528 | + case "DECIMAL": |
| 529 | + case "BIGDECIMAL": |
| 530 | + case "BIG_DECIMAL": |
| 531 | + return DataType.BIG_DECIMAL; |
| 532 | + case "BOOL": |
| 533 | + case "BOOLEAN": |
| 534 | + return DataType.BOOLEAN; |
| 535 | + case "TIMESTAMP": |
| 536 | + return DataType.TIMESTAMP; |
| 537 | + case "STRING": |
| 538 | + case "VARCHAR": |
| 539 | + return DataType.STRING; |
| 540 | + case "JSON": |
| 541 | + return DataType.JSON; |
| 542 | + case "BYTES": |
| 543 | + case "VARBINARY": |
| 544 | + return DataType.BYTES; |
| 545 | + case "UUID": |
| 546 | + return DataType.UUID; |
| 547 | + default: |
| 548 | + throw new BadQueryRequestException("Unable to cast expression to type - " + targetTypeString); |
| 549 | + } |
| 550 | + } |
| 551 | + |
| 552 | + private static EvaluatedLiteralValue fromLiteralContext(LiteralContext literalContext) { |
| 553 | + return new EvaluatedLiteralValue(literalContext.getStringValue(), literalContext.getType()); |
470 | 554 | } |
471 | 555 |
|
472 | 556 | private static float[] getVectorValue(ExpressionContext expressionContext) { |
@@ -556,4 +640,22 @@ private static float getFloatValue(Expression thriftExpression) { |
556 | 640 | throw new IllegalStateException("Unsupported literal type: " + type); |
557 | 641 | } |
558 | 642 | } |
| 643 | + |
| 644 | + private static final class EvaluatedLiteralValue { |
| 645 | + private final String _stringValue; |
| 646 | + private final DataType _type; |
| 647 | + |
| 648 | + private EvaluatedLiteralValue(String stringValue, DataType type) { |
| 649 | + _stringValue = stringValue; |
| 650 | + _type = type; |
| 651 | + } |
| 652 | + |
| 653 | + private String getStringValue() { |
| 654 | + return _stringValue; |
| 655 | + } |
| 656 | + |
| 657 | + private DataType getType() { |
| 658 | + return _type; |
| 659 | + } |
| 660 | + } |
559 | 661 | } |
0 commit comments