Skip to content

Commit

Permalink
Avro: Support variant in Avro readers, writers.
Browse files Browse the repository at this point in the history
  • Loading branch information
rdblue committed Mar 5, 2025
1 parent 6673422 commit c81f8be
Show file tree
Hide file tree
Showing 17 changed files with 224 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@

import java.nio.ByteBuffer;

/**
 * A variant metadata or value object that is backed by an already-serialized buffer.
 *
 * <p>Writers check for this interface to copy the underlying bytes directly instead of
 * re-serializing the object.
 */
public interface Serialized {
  /** Returns the serialized bytes backing this object. */
  ByteBuffer buffer();
}
1 change: 1 addition & 0 deletions core/src/main/java/org/apache/iceberg/avro/Avro.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ private enum Codec {
LogicalTypes.register(VariantLogicalType.NAME, schema -> VariantLogicalType.get());
DEFAULT_MODEL.addLogicalTypeConversion(new Conversions.DecimalConversion());
DEFAULT_MODEL.addLogicalTypeConversion(new UUIDConversion());
DEFAULT_MODEL.addLogicalTypeConversion(new VariantConversion());
}

public static WriteBuilder write(OutputFile file) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ public R map(P partner, Schema map, R valueResult) {
return null;
}

// Called for schemas carrying the variant logical type; default rejects them so that
// visitors that do not handle variants fail loudly instead of mis-visiting the record.
public R variant(P partner, R metadataResult, R valueResult) {
  throw new UnsupportedOperationException("Visitor does not support variant");
}

public R primitive(P partner, Schema primitive) {
return null;
}
Expand All @@ -100,7 +104,11 @@ public static <P, R> R visit(
PartnerAccessors<P> accessors) {
switch (schema.getType()) {
case RECORD:
return visitRecord(partner, schema, visitor, accessors);
if (schema.getLogicalType() instanceof VariantLogicalType) {
return visitVariant(partner, schema, visitor, accessors);
} else {
return visitRecord(partner, schema, visitor, accessors);
}

case UNION:
return visitUnion(partner, schema, visitor, accessors);
Expand All @@ -123,6 +131,27 @@ public static <P, R> R visit(
}
}

// Visits a variant schema (a record with "metadata" and "value" binary fields) and
// produces the visitor's variant result from the two field results.
private static <P, R> R visitVariant(
    P partner,
    Schema variant,
    AvroWithPartnerVisitor<P, R> visitor,
    PartnerAccessors<P> accessors) {
  // check to make sure this hasn't been visited before
  String recordName = variant.getFullName();
  Preconditions.checkState(
      !visitor.recordLevels.contains(recordName),
      "Cannot process recursive Avro record %s",
      recordName);
  // push before descending so nested references to this record name are detected;
  // popped below once both fields have been visited
  visitor.recordLevels.push(recordName);

  // the partner is not forwarded to the fields: metadata/value are binary and have no
  // meaningful partner type of their own
  R metadataResult = visit(null, variant.getField("metadata").schema(), visitor, accessors);
  R valueResult = visit(null, variant.getField("value").schema(), visitor, accessors);

  visitor.recordLevels.pop();

  return visitor.variant(partner, metadataResult, valueResult);
}

private static <P, R> R visitRecord(
P partnerStruct,
Schema record,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ public ValueWriter<?> map(Schema map, ValueWriter<?> valueWriter) {
return ValueWriters.map(ValueWriters.strings(), valueWriter);
}

// Builds a variant writer from the writers produced for the metadata and value fields.
@Override
public ValueWriter<?> variant(
    Schema variant, ValueWriter<?> metadataResult, ValueWriter<?> valueResult) {
  return ValueWriters.variants(metadataResult, valueResult);
}

@Override
public ValueWriter<?> primitive(Schema primitive) {
LogicalType logicalType = primitive.getLogicalType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ public ValueReader<?> map(Type partner, Schema map, ValueReader<?> valueReader)
return ValueReaders.map(ValueReaders.strings(), valueReader);
}

// Builds a variant reader from the readers produced for the metadata and value fields.
@Override
public ValueReader<?> variant(
    Type partner, ValueReader<?> metadataReader, ValueReader<?> valueReader) {
  return ValueReaders.variants(metadataReader, valueReader);
}

@Override
public ValueReader<?> primitive(Type partner, Schema primitive) {
LogicalType logicalType = primitive.getLogicalType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ public ValueReader<?> map(Pair<Integer, Type> partner, Schema map, ValueReader<?
return ValueReaders.map(ValueReaders.strings(), valueReader);
}

// Builds a variant reader from the readers produced for the metadata and value fields.
@Override
public ValueReader<?> variant(
    Pair<Integer, Type> partner, ValueReader<?> metadataReader, ValueReader<?> valueReader) {
  return ValueReaders.variants(metadataReader, valueReader);
}

@Override
public ValueReader<?> primitive(Pair<Integer, Type> partner, Schema primitive) {
LogicalType logicalType = primitive.getLogicalType();
Expand Down
13 changes: 9 additions & 4 deletions core/src/main/java/org/apache/iceberg/avro/ValueReaders.java
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,11 @@ public static ValueReader<byte[]> decimalBytesReader(Schema schema) {
}
}

@SuppressWarnings("unchecked")
public static ValueReader<Variant> variants(
ValueReader<ByteBuffer> metadataReader, ValueReader<ByteBuffer> valueReader) {
return new VariantReader(metadataReader, valueReader);
ValueReader<?> metadataReader, ValueReader<?> valueReader) {
return new VariantReader(
(ValueReader<ByteBuffer>) metadataReader, (ValueReader<ByteBuffer>) valueReader);
}

public static ValueReader<Object> union(List<ValueReader<?>> readers) {
Expand Down Expand Up @@ -673,8 +675,11 @@ public VariantReader(

@Override
public Variant read(Decoder decoder, Object reuse) throws IOException {
  // variant metadata and value are stored as Avro binary; the variant encoding is
  // little-endian, while Avro hands back buffers in the default (big-endian) order
  VariantMetadata metadata =
      VariantMetadata.from(metadataReader.read(decoder, null).order(ByteOrder.LITTLE_ENDIAN));
  // fix: the value bytes must come from valueReader; reading with metadataReader
  // consumed the wrong field and decoded the metadata bytes as the value
  VariantValue value =
      VariantValue.from(
          metadata, valueReader.read(decoder, null).order(ByteOrder.LITTLE_ENDIAN));
  return Variant.of(metadata, value);
}

Expand Down
46 changes: 46 additions & 0 deletions core/src/main/java/org/apache/iceberg/avro/ValueWriters.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.util.DecimalUtil;
import org.apache.iceberg.util.UUIDUtil;
import org.apache.iceberg.variants.Serialized;
import org.apache.iceberg.variants.Variant;
import org.apache.iceberg.variants.VariantMetadata;
import org.apache.iceberg.variants.VariantValue;

public class ValueWriters {
private ValueWriters() {}
Expand Down Expand Up @@ -109,6 +113,13 @@ public static ValueWriter<BigDecimal> decimal(int precision, int scale) {
return new DecimalWriter(precision, scale);
}

// Returns a writer that combines metadata and value binary writers into a Variant writer.
// The casts are safe because variant fields are always Avro binary, written as ByteBuffer.
@SuppressWarnings("unchecked")
public static ValueWriter<Variant> variants(
    ValueWriter<?> metadataWriter, ValueWriter<?> valueWriter) {
  return new VariantWriter(
      (ValueWriter<ByteBuffer>) metadataWriter, (ValueWriter<ByteBuffer>) valueWriter);
}

public static <T> ValueWriter<T> option(int nullIndex, ValueWriter<T> writer) {
return new OptionWriter<>(nullIndex, writer);
}
Expand Down Expand Up @@ -373,6 +384,41 @@ public void write(BigDecimal decimal, Encoder encoder) throws IOException {
}
}

/**
 * Writes a {@link Variant} as two Avro binary fields: metadata, then value.
 *
 * <p>If the metadata or value is already backed by a serialized buffer ({@link Serialized}),
 * its bytes are written directly; otherwise the object is serialized into a little-endian
 * buffer first.
 */
private static class VariantWriter implements ValueWriter<Variant> {
  private final ValueWriter<ByteBuffer> metadataWriter;
  private final ValueWriter<ByteBuffer> valueWriter;

  private VariantWriter(
      ValueWriter<ByteBuffer> metadataWriter, ValueWriter<ByteBuffer> valueWriter) {
    this.metadataWriter = metadataWriter;
    this.valueWriter = valueWriter;
  }

  @Override
  public void write(Variant variant, Encoder encoder) throws IOException {
    VariantMetadata metadata = variant.metadata();
    if (metadata instanceof Serialized) {
      // already serialized: write the backing buffer without copying
      metadataWriter.write(((Serialized) metadata).buffer(), encoder);
    } else {
      // TODO: reuse buffers using buffer size code from Parquet
      ByteBuffer metadataBuffer =
          ByteBuffer.allocate(metadata.sizeInBytes()).order(ByteOrder.LITTLE_ENDIAN);
      // use the captured local instead of re-calling variant.metadata()
      metadata.writeTo(metadataBuffer, 0);
      metadataWriter.write(metadataBuffer, encoder);
    }

    VariantValue value = variant.value();
    if (value instanceof Serialized) {
      valueWriter.write(((Serialized) value).buffer(), encoder);
    } else {
      ByteBuffer valueBuffer =
          ByteBuffer.allocate(value.sizeInBytes()).order(ByteOrder.LITTLE_ENDIAN);
      value.writeTo(valueBuffer, 0);
      valueWriter.write(valueBuffer, encoder);
    }
  }
}

private static class OptionWriter<T> implements ValueWriter<T> {
private final int nullIndex;
private final int valueIndex;
Expand Down
68 changes: 68 additions & 0 deletions core/src/main/java/org/apache/iceberg/avro/VariantConversion.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.avro;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import org.apache.avro.Conversion;
import org.apache.avro.LogicalType;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.generic.IndexedRecord;
import org.apache.iceberg.variants.Variant;
import org.apache.iceberg.variants.VariantMetadata;
import org.apache.iceberg.variants.VariantValue;

/**
 * Avro {@link Conversion} between the variant logical type's record representation (binary
 * fields "metadata" and "value") and {@link Variant} instances.
 */
public class VariantConversion extends Conversion<Variant> {
  @Override
  public Class<Variant> getConvertedType() {
    return Variant.class;
  }

  @Override
  public String getLogicalTypeName() {
    return VariantLogicalType.NAME;
  }

  @Override
  public Variant fromRecord(IndexedRecord record, Schema schema, LogicalType type) {
    int metadataPos = schema.getField("metadata").pos();
    int valuePos = schema.getField("value").pos();
    // set little-endian order explicitly: the variant encoding is little-endian, and Avro
    // returns buffers in the default (big-endian) order — matches the ValueReaders path
    VariantMetadata metadata =
        VariantMetadata.from(
            ((ByteBuffer) record.get(metadataPos)).order(ByteOrder.LITTLE_ENDIAN));
    VariantValue value =
        VariantValue.from(
            metadata, ((ByteBuffer) record.get(valuePos)).order(ByteOrder.LITTLE_ENDIAN));
    return Variant.of(metadata, value);
  }

  @Override
  public IndexedRecord toRecord(Variant variant, Schema schema, LogicalType type) {
    int metadataPos = schema.getField("metadata").pos();
    int valuePos = schema.getField("value").pos();
    GenericRecord record = new GenericData.Record(schema);
    // TODO: reuse buffers; see the buffer sizing code used by the Parquet writers
    ByteBuffer metadataBuffer =
        ByteBuffer.allocate(variant.metadata().sizeInBytes()).order(ByteOrder.LITTLE_ENDIAN);
    variant.metadata().writeTo(metadataBuffer, 0);
    record.put(metadataPos, metadataBuffer);
    ByteBuffer valueBuffer =
        ByteBuffer.allocate(variant.value().sizeInBytes()).order(ByteOrder.LITTLE_ENDIAN);
    variant.value().writeTo(valueBuffer, 0);
    record.put(valuePos, valueBuffer);
    return record;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ public ValueWriter<?> map(Schema map, ValueWriter<?> valueWriter) {
return ValueWriters.map(ValueWriters.strings(), valueWriter);
}

// Builds a variant writer from the writers produced for the metadata and value fields.
@Override
public ValueWriter<?> variant(
    Schema variant, ValueWriter<?> metadataResult, ValueWriter<?> valueResult) {
  return ValueWriters.variants(metadataResult, valueResult);
}

@Override
public ValueWriter<?> primitive(Schema primitive) {
LogicalType logicalType = primitive.getLogicalType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ public ValueReader<?> map(Type ignored, Schema map, ValueReader<?> valueReader)
return ValueReaders.map(ValueReaders.strings(), valueReader);
}

// Builds a variant reader from the readers produced for the metadata and value fields.
@Override
public ValueReader<?> variant(
    Type partner, ValueReader<?> metadataReader, ValueReader<?> valueReader) {
  return ValueReaders.variants(metadataReader, valueReader);
}

@Override
public ValueReader<?> primitive(Type partner, Schema primitive) {
LogicalType logicalType = primitive.getLogicalType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ private static void assertEquals(Type type, Object expected, Object actual) {
case DECIMAL:
assertThat(actual).as("Primitive value should be equal to expected").isEqualTo(expected);
break;
case VARIANT:
assertThat(expected).as("Expected should be a Variant").isInstanceOf(Variant.class);
assertThat(actual).as("Actual should be a Variant").isInstanceOf(Variant.class);
Variant expectedVariant = (Variant) expected;
Variant actualVariant = (Variant) actual;
VariantTestUtil.assertEqual(expectedVariant.metadata(), actualVariant.metadata());
VariantTestUtil.assertEqual(expectedVariant.value(), actualVariant.value());
break;
case STRUCT:
assertThat(expected).as("Expected should be a StructLike").isInstanceOf(StructLike.class);
assertThat(actual).as("Actual should be a StructLike").isInstanceOf(StructLike.class);
Expand Down
10 changes: 10 additions & 0 deletions core/src/test/java/org/apache/iceberg/avro/AvroTestHelpers.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import org.apache.avro.generic.GenericData.Record;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.variants.Variant;
import org.apache.iceberg.variants.VariantTestUtil;

class AvroTestHelpers {

Expand Down Expand Up @@ -135,6 +137,14 @@ private static void assertEquals(Type type, Object expected, Object actual) {
case DECIMAL:
assertThat(actual).as("Primitive value should be equal to expected").isEqualTo(expected);
break;
case VARIANT:
assertThat(expected).as("Expected should be a Variant").isInstanceOf(Variant.class);
assertThat(actual).as("Actual should be a Variant").isInstanceOf(Variant.class);
Variant expectedVariant = (Variant) expected;
Variant actualVariant = (Variant) actual;
VariantTestUtil.assertEqual(expectedVariant.metadata(), actualVariant.metadata());
VariantTestUtil.assertEqual(expectedVariant.value(), actualVariant.value());
break;
case STRUCT:
assertThat(expected).as("Expected should be a Record").isInstanceOf(Record.class);
assertThat(actual).as("Actual should be a Record").isInstanceOf(Record.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericData.Record;
import org.apache.avro.util.Utf8;
import org.apache.iceberg.RandomVariants;
import org.apache.iceberg.Schema;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.types.Type;
Expand Down Expand Up @@ -93,6 +94,11 @@ public Object map(Types.MapType map, Supplier<Object> keyResult, Supplier<Object
return RandomUtil.generateMap(random, map, keyResult, valueResult);
}

// Generates a random Variant for test data; delegates to the shared RandomVariants helper
// so variant generation is consistent across test suites.
@Override
public Object variant(Types.VariantType variant) {
  return RandomVariants.randomVariant(random);
}

@Override
public Object primitive(Type.PrimitiveType primitive) {
Object result = RandomUtil.generatePrimitive(primitive, random);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ protected boolean supportsTimestampNanos() {
return true;
}

// Opt this Avro test suite into the shared variant read/write test cases.
@Override
protected boolean supportsVariant() {
  return true;
}

@Override
protected void writeAndValidate(Schema schema) throws IOException {
List<Record> expected = RandomAvroData.generate(schema, 100, 0L);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ protected boolean supportsTimestampNanos() {
return true;
}

// Opt this Avro test suite into the shared variant read/write test cases.
@Override
protected boolean supportsVariant() {
  return true;
}

@Override
protected void writeAndValidate(Schema schema) throws IOException {
List<Record> expected = RandomInternalData.generate(schema, 100, 42L);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,9 @@ protected boolean supportsUnknown() {
protected boolean supportsTimestampNanos() {
return true;
}

// Opt this Avro test suite into the shared variant read/write test cases.
@Override
protected boolean supportsVariant() {
  return true;
}
}

0 comments on commit c81f8be

Please sign in to comment.