Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions clickhouse/base/wire_format.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <assert.h>
#include "wire_format.h"

#include "input.h"
Expand All @@ -6,6 +7,7 @@
#include "../exceptions.h"

#include <stdexcept>
#include <algorithm>

namespace {
constexpr int MAX_VARINT_BYTES = 10;
Expand Down Expand Up @@ -99,4 +101,77 @@ bool WireFormat::SkipString(InputStream& input) {
return false;
}

inline const char* find_quoted_chars(const char* start, const char* end)
{
static constexpr char quoted_chars[] = {'\0', '\b', '\t', '\n', '\'', '\\'};
const auto first = std::find_first_of(start, end, std::begin(quoted_chars), std::end(quoted_chars));

return (first == end) ? nullptr : first;
}

void WireFormat::WriteQuotedString(OutputStream& output, std::string_view value) {
auto size = value.size();
const char* start = value.data();
const char* end = start + size;
const char* quoted_char = find_quoted_chars(start, end);
if (quoted_char == nullptr) {
WriteVarint64(output, size + 2);
WriteAll(output, "'", 1);
WriteAll(output, start, size);
WriteAll(output, "'", 1);
return;
}

// calculate quoted chars count
int quoted_count = 1;
const char* next_quoted_char = quoted_char + 1;
while ((next_quoted_char = find_quoted_chars(next_quoted_char, end))) {
quoted_count++;
next_quoted_char++;
}
WriteVarint64(output, size + 2 + 3 * quoted_count); // length

WriteAll(output, "'", 1);

do {
auto write_size = quoted_char - start;
WriteAll(output, start, write_size);
WriteAll(output, "\\", 1);
char c = quoted_char[0];
switch (c) {
case '\0':
WriteAll(output, "x00", 3);
break;
case '\b':
WriteAll(output, "x08", 3);
break;
case '\t':
WriteAll(output, R"(\\t)", 3);
break;
case '\n':
WriteAll(output, R"(\\n)", 3);
break;
case '\'':
WriteAll(output, "x27", 3);
break;
case '\\':
WriteAll(output, R"(\\\)", 3);
break;
default:
break;
}
start = quoted_char + 1;
quoted_char = find_quoted_chars(start, end);
} while (quoted_char);

WriteAll(output, start, end - start);
WriteAll(output, "'", 1);
}

void WireFormat::WriteParamNullRepresentation(OutputStream& output) {
const std::string NULL_REPRESENTATION(R"('\\N')");
WriteVarint64(output, NULL_REPRESENTATION.size());
WriteAll(output, NULL_REPRESENTATION.data(), NULL_REPRESENTATION.size());
}

} // namespace clickhouse
2 changes: 2 additions & 0 deletions clickhouse/base/wire_format.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class WireFormat {
static void WriteFixed(OutputStream& output, const T& value);
static void WriteBytes(OutputStream& output, const void* buf, size_t len);
static void WriteString(OutputStream& output, std::string_view value);
static void WriteQuotedString(OutputStream& output, std::string_view value);
static void WriteParamNullRepresentation(OutputStream& output);
static void WriteUInt64(OutputStream& output, const uint64_t value);
static void WriteVarint64(OutputStream& output, uint64_t value);

Expand Down
55 changes: 51 additions & 4 deletions clickhouse/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,13 @@
#define DBMS_MIN_REVISION_WITH_DISTRIBUTED_DEPTH 54448
#define DBMS_MIN_REVISION_WITH_INITIAL_QUERY_START_TIME 54449
#define DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS 54451
#define DBMS_MIN_REVISION_WITH_PARALLEL_REPLICAS 54453
#define DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION 54454 // Client can get some fields in JSon format
#define DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM 54458 // send quota key after handshake
#define DBMS_MIN_PROTOCOL_REVISION_WITH_QUOTA_KEY 54458 // the same
#define DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS 54459

#define DMBS_PROTOCOL_REVISION DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS
#define DMBS_PROTOCOL_REVISION DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS

namespace clickhouse {

Expand Down Expand Up @@ -433,6 +438,11 @@ bool Client::Impl::Handshake() {
if (!ReceiveHello()) {
return false;
}

if (server_info_.revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM) {
WireFormat::WriteString(*output_, std::string());
}

return true;
}

Expand Down Expand Up @@ -502,7 +512,7 @@ bool Client::Impl::ReceivePacket(uint64_t* server_packet) {
return false;
}
}
if constexpr (DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO)
if (server_info_.revision >= DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO)
{
if (!WireFormat::ReadUInt64(*input_, &info.written_rows)) {
return false;
Expand Down Expand Up @@ -589,7 +599,7 @@ bool Client::Impl::ReceivePacket(uint64_t* server_packet) {

bool Client::Impl::ReadBlock(InputStream& input, Block* block) {
// Additional information about block.
if constexpr (DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) {
if (server_info_.revision >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) {
uint64_t num;
BlockInfo info;

Expand Down Expand Up @@ -635,6 +645,16 @@ bool Client::Impl::ReadBlock(InputStream& input, Block* block) {
if (!WireFormat::ReadString(input, &type)) {
return false;
}

if (server_info_.revision >= DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION) {
uint8_t custom_format_len;
if (!WireFormat::ReadFixed(input, &custom_format_len)) {
return false;
}
if (custom_format_len > 0) {
throw UnimplementedError(std::string("unsupported custom serialization"));
}
}

if (ColumnRef col = CreateColumnByType(type, create_column_settings)) {
if (num_rows && !col->Load(&input, num_rows)) {
Expand All @@ -653,7 +673,7 @@ bool Client::Impl::ReadBlock(InputStream& input, Block* block) {
bool Client::Impl::ReceiveData() {
Block block;

if constexpr (DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) {
if (server_info_.revision >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) {
if (!WireFormat::SkipString(*input_)) {
return false;
}
Expand Down Expand Up @@ -793,6 +813,12 @@ void Client::Impl::SendQuery(const Query& query) {
throw UnimplementedError(std::string("Can't send open telemetry tracing context to a server, server version is too old"));
}
}
if (server_info_.revision >= DBMS_MIN_REVISION_WITH_PARALLEL_REPLICAS) {
// replica dont supported by client
WireFormat::WriteUInt64(*output_, 0);
WireFormat::WriteUInt64(*output_, 0);
WireFormat::WriteUInt64(*output_, 0);
}
}

/// Per query settings
Expand All @@ -817,6 +843,22 @@ void Client::Impl::SendQuery(const Query& query) {
WireFormat::WriteUInt64(*output_, Stages::Complete);
WireFormat::WriteUInt64(*output_, compression_);
WireFormat::WriteString(*output_, query.GetText());

//Send params after query text
if (server_info_.revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS) {
for(const auto& [name, value] : query.GetParams()) {
// params is like query settings
WireFormat::WriteString(*output_, name);
const uint64_t Custom = 2;
WireFormat::WriteVarint64(*output_, Custom);
if (value)
WireFormat::WriteQuotedString(*output_, *value);
else
WireFormat::WriteParamNullRepresentation(*output_);
}
WireFormat::WriteString(*output_, std::string()); // empty string after last param
}

// Send empty block as marker of
// end of data
SendData(Block());
Expand All @@ -842,6 +884,11 @@ void Client::Impl::WriteBlock(const Block& block, OutputStream& output) {
WireFormat::WriteString(output, bi.Name());
WireFormat::WriteString(output, bi.Type()->GetName());

if (server_info_.revision >= DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION) {
// TODO: custom serialization
WireFormat::WriteFixed<uint8_t>(output, 0);
}

// Empty columns are not serialized and occupy exactly 0 bytes.
// ref https://github.com/ClickHouse/ClickHouse/blob/39b37a3240f74f4871c8c1679910e065af6bea19/src/Formats/NativeWriter.cpp#L163
const bool containsData = block.GetRowCount() > 0;
Expand Down
15 changes: 15 additions & 0 deletions clickhouse/query.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ struct QuerySettingsField {
};

using QuerySettings = std::unordered_map<std::string, QuerySettingsField>;
using QueryParamValue = std::optional<std::string>;
using QueryParams = std::unordered_map<std::string, QueryParamValue>;

struct Profile {
uint64_t rows = 0;
Expand Down Expand Up @@ -115,6 +117,18 @@ class Query : public QueryEvents {
return *this;
}

inline const QueryParams& GetParams() const { return query_params_; }

inline Query& SetParams(QueryParams query_params) {
query_params_ = std::move(query_params);
return *this;
}

inline Query& SetParam(const std::string& name, const QueryParamValue& value) {
query_params_[name] = value;
return *this;
}

inline const std::optional<open_telemetry::TracingContext>& GetTracingContext() const {
return tracing_context_;
}
Expand Down Expand Up @@ -219,6 +233,7 @@ class Query : public QueryEvents {
const std::string query_id_;
std::optional<open_telemetry::TracingContext> tracing_context_;
QuerySettings query_settings_;
QueryParams query_params_;
ExceptionCallback exception_cb_;
ProgressCallback progress_cb_;
SelectCallback select_cb_;
Expand Down
77 changes: 77 additions & 0 deletions tests/simple/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,81 @@ inline void GenericExample(Client& client) {
client.Execute("DROP TEMPORARY TABLE test_client");
}

inline void ParamExample(Client& client) {
/// Create a table.
client.Execute("CREATE TEMPORARY TABLE IF NOT EXISTS test_client (id UInt64, name String)");

{
Query query("insert into test_client values ({id: UInt64}, {name: String})");

query.SetParam("id", "1").SetParam("name", "NAME");
client.Execute(query);

query.SetParam("id", "123").SetParam("name", "FromParam");
client.Execute(query);

const char FirstPrintable = ' ';
char test_str1[FirstPrintable * 2 + 1];
for (unsigned int i = 0; i < FirstPrintable; i++) {
test_str1[i * 2] = 'A';
test_str1[i * 2 + 1] = i;
}
test_str1[int(FirstPrintable * 2)] = 'A';

query.SetParam("id", "333").SetParam("name", std::string(test_str1, FirstPrintable * 2 + 1));
client.Execute(query);

const char LastPrintable = 127;
unsigned char big_string[LastPrintable - FirstPrintable];
for (unsigned int i = 0; i < sizeof(big_string); i++) big_string[i] = i + FirstPrintable;
query.SetParam("id", "444").SetParam("name", std::string((char*)big_string, sizeof(big_string)));
client.Execute(query);

query.SetParam("id", "555")
.SetParam("name", "utf8Русский");
client.Execute(query);
}

/// Select values inserted in the previous step.
Query query ("SELECT id, name, length(name) FROM test_client where id > {a: Int32}");
query.SetParam("a", "4");
SelectCallback cb([](const Block& block)
{
std::cout << PrettyPrintBlock{block} << std::endl;
});
query.OnData(cb);
client.Select(query);
/// Delete table.
client.Execute("DROP TEMPORARY TABLE test_client");
}

inline void ParamNullExample(Client& client) {
client.Execute("CREATE TEMPORARY TABLE IF NOT EXISTS test_client (id UInt64, name Nullable(String))");

Query query("insert into test_client values ({id: UInt64}, {name: Nullable(String)})");

query.SetParam("id", "123").SetParam("name", QueryParamValue());
client.Execute(query);

query.SetParam("id", "456").SetParam("name", "String Value");
client.Execute(query);

client.Select("SELECT id, name FROM test_client", [](const Block& block) {
for (size_t c = 0; c < block.GetRowCount(); ++c) {
std::cerr << block[0]->As<ColumnUInt64>()->At(c) << " ";

auto col_string = block[1]->As<ColumnNullable>();
if (col_string->IsNull(c)) {
std::cerr << "\\N\n";
} else {
std::cerr << col_string->Nested()->As<ColumnString>()->At(c) << "\n";
}
}
});

client.Execute("DROP TEMPORARY TABLE test_client");
}

inline void NullableExample(Client& client) {
/// Create a table.
client.Execute("CREATE TEMPORARY TABLE IF NOT EXISTS test_client (id Nullable(UInt64), date Nullable(Date))");
Expand Down Expand Up @@ -478,6 +553,8 @@ inline void IPExample(Client &client) {
}

static void RunTests(Client& client) {
ParamExample(client);
ParamNullExample(client);
ArrayExample(client);
CancelableExample(client);
DateExample(client);
Expand Down
Loading
Loading