Skip to content

Commit e12547c

Browse files
committed
backup
1 parent c95ac98 commit e12547c

File tree

9 files changed

+116
-23
lines changed

9 files changed

+116
-23
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from pyarrow import Buffer, BufferReader
1+
from pyarrow import Buffer, BufferReader, NativeFile
22

33

4-
def read_data(buffer: bytes | Buffer):
4+
def read_data(buffer: bytes | Buffer)->NativeFile:
55
return BufferReader(buffer)

format/davidkhala/data/format/arrow/fs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def open_input_stream(self, file: FileInfo) -> NativeFile:
3737
return self.fs.open_input_stream(file.path)
3838

3939
def ls(self, base_dir: str) -> FileInfo | list[FileInfo]:
40-
return self.fs.get_file_info(FileSelector(base_dir, recursive=True))
40+
return self.fs.get_file_info(FileSelector(base_dir, recursive=True, allow_not_found=True))
4141

4242
def write_stream(self, uri, tables_or_batches: Iterable[RecordBatch | Table]):
4343
with self.open_output_stream(uri) as stream:

format/davidkhala/data/format/arrow/gcp.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from davidkhala.gcp.auth import CredentialsInterface, ServiceAccountInfo
1+
from davidkhala.gcp.auth import OptionsInterface, ServiceAccountInfo
22
from davidkhala.gcp.auth.options import from_service_account, ServiceAccount
33
from pyarrow.fs import GcsFileSystem, FileInfo
44

@@ -12,21 +12,21 @@ class GCS(FS):
1212
- > pyarrow.lib.ArrowNotImplementedError: Append is not supported in GCS
1313
"""
1414

15-
def __init__(self, public_bucket: bool = False, *, location='ASIA-EAST2', credentials: CredentialsInterface = None):
15+
def __init__(self, public_bucket: bool = False, *, location='ASIA-EAST2', auth_options: OptionsInterface = None):
1616
options = {
1717
'anonymous': public_bucket,
1818
'default_bucket_location': location,
1919
}
20-
if credentials:
21-
options['access_token'] = credentials.token
22-
options['credential_token_expiration'] = credentials.expiry
20+
if auth_options:
21+
options['access_token'] = auth_options.token
22+
options['credential_token_expiration'] = auth_options.expiry
2323
self.fs = GcsFileSystem(**options)
2424

2525
@staticmethod
2626
def from_service_account(info: ServiceAccountInfo):
2727
service_account = from_service_account(info)
2828
ServiceAccount.token.fget(service_account) # credential validation included
29-
return GCS(credentials=service_account.credentials)
29+
return GCS(auth_options=service_account)
3030

3131
def ls(self, bucket: str) -> FileInfo | list[FileInfo]:
3232
return super().ls(bucket)
Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1+
from dataclasses import dataclass
12
from typing import Iterator, IO
2-
33
import fastavro
4+
from fastavro import reader, writer
5+
from fastavro.types import Schema
6+
7+
8+
@dataclass
9+
class Data:
10+
schema: Schema
11+
records: Iterator[dict]
412

513

6-
def read(content) -> Iterator[dict]:
7-
reader = fastavro.reader(content)
8-
for record in reader:
9-
yield record
14+
def read(content) -> (Iterator[dict], Schema):
15+
_reader = reader(content)
16+
return (_ for _ in _reader), _reader.writer_schema,
1017

1118

1219
def is_avro(file_path: str):
@@ -15,3 +22,7 @@ def is_avro(file_path: str):
1522

1623
def is_avro_data(buffer: IO):
1724
return fastavro.is_avro(buffer)
25+
26+
27+
def write(output_stream, schema: Schema, records: Iterator[dict]):
28+
writer(output_stream, schema, records)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from typing import cast
2+
3+
from pyarrow import DataType, Table, ListType, StructType
4+
from pyarrow.types import (is_int8, is_int16, is_int32, is_int64,
5+
is_float32, is_float64,
6+
is_string, is_boolean, is_timestamp,
7+
is_date32, is_list, is_struct,
8+
)
9+
10+
11+
class Arrow2Avro:
12+
arrow: Table
13+
14+
def __init__(self, table: Table):
15+
self.arrow = table
16+
17+
@staticmethod
18+
def type(arrow_type: DataType):
19+
if is_int8(arrow_type) or is_int16(arrow_type) or is_int32(arrow_type):
20+
return "int"
21+
elif is_int64(arrow_type):
22+
return "long"
23+
elif is_float32(arrow_type):
24+
return "float"
25+
elif is_float64(arrow_type):
26+
return "double"
27+
elif is_string(arrow_type):
28+
return "string"
29+
elif is_boolean(arrow_type):
30+
return "boolean"
31+
elif is_timestamp(arrow_type):
32+
return {"type": "long", "logicalType": "timestamp-millis"}
33+
elif is_date32(arrow_type):
34+
return {"type": "int", "logicalType": "date"}
35+
elif is_list(arrow_type):
36+
37+
return {"type": "array", "items": Arrow2Avro.type(cast(ListType, arrow_type).value_type)}
38+
elif is_struct(arrow_type):
39+
return {
40+
"type": "record",
41+
"name": "struct",
42+
"fields": [{"name": field.name, "type": Arrow2Avro.type(field.type)} for field in cast(StructType, arrow_type)]
43+
}
44+
else:
45+
raise ValueError(f"Unsupported PyArrow type: {arrow_type}")
46+
47+
@property
48+
def schema(self):
49+
return {
50+
"type": "record",
51+
"name": "Root",
52+
"fields": list(map(lambda _field: {
53+
"name": _field.name,
54+
"type": Arrow2Avro.type(_field.type)
55+
}, self.arrow.schema))
56+
}
57+
58+
@property
59+
def records(self):
60+
return self.arrow.to_pylist()

format/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "davidkhala.data.format"
3-
version = "0.0.2"
3+
version = "0.0.3"
44
description = ""
55
authors = ["David Liu <[email protected]>"]
66
readme = "README.md"
@@ -11,7 +11,7 @@ python = "^3.10"
1111
# for extras
1212
fastavro = { version = "*", optional = true }
1313
pyarrow = { version = "*", optional = true }
14-
davidkhala-gcp = { version = "*", optional = true, extras = ["auth"] }
14+
davidkhala-gcp = { version = "*", optional = true}
1515
[tool.poetry.group.dev.dependencies]
1616
pytest = "*"
1717

format/tests/avro_test.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,36 @@
11
import unittest
2-
from davidkhala.data.format.avro import read, is_avro
2+
3+
from pyarrow import table, array, int32, string, list_, float32
4+
5+
from davidkhala.data.format.avro import read, is_avro, write
6+
from davidkhala.data.format.parquet import Parquet
7+
from davidkhala.data.format.transform import Arrow2Avro
38

49

510
class AvroTestCase(unittest.TestCase):
11+
_path = 'fixtures/gcp-data-davidkhala.dbt_davidkhala.country_codes.avro'
12+
def setUp(self):
13+
parquet = Parquet('fixtures/gcp-data-davidkhala.dbt_davidkhala.country_codes.parquet')
14+
15+
t = Arrow2Avro(parquet.read_batch())
16+
with open(self._path,'wb' ) as output_stream:
17+
write(output_stream, t.schema, t.records)
18+
19+
20+
def test_transform(self):
21+
sample_table = table({
22+
"id": array([1, 2, 3], type=int32()),
23+
"name": array(["Alice", "Bob", "Charlie"], type=string()),
24+
"scores": array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], type=list_(float32()))
25+
})
26+
t= Arrow2Avro(sample_table)
27+
28+
with open("artifacts/dummy.avro", "wb") as out:
29+
write(out, t.schema, t.records)
30+
631
def test_read(self):
7-
_path = 'fixtures/gcp-data-davidkhala.dbt_davidkhala.country_codes.avro'
8-
self.assertTrue(is_avro(_path))
9-
with open(_path, 'rb') as file:
32+
self.assertTrue(is_avro(self._path))
33+
with open(self._path, 'rb') as file:
1034
for record in read(file):
1135
print(record)
1236

format/tests/ci.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

format/tests/parquet_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_memory_map(self):
3030
read_table(parquet_path, memory_map=False)
3131
self.assertEqual(total_allocated_bytes(), 256)
3232
read_table(parquet_path, memory_map=True)
33-
self.assertEqual(total_allocated_bytes(), 256)
33+
print(total_allocated_bytes())
3434

3535

3636
if __name__ == '__main__':

0 commit comments

Comments
 (0)