use std::{collections::HashMap, convert::TryInto, fmt::Display, sync::Arc};
use prost::Message;
use prost_types::{
field_descriptor_proto::{Label, Type},
DescriptorProto, FieldDescriptorProto,
};
use tonic::{
transport::{Channel, ClientTlsConfig},
Request, Streaming,
};
use crate::{
auth::Authenticator,
error::BQError,
google::cloud::bigquery::storage::v1::{
append_rows_request::{self, MissingValueInterpretation, ProtoData},
big_query_write_client::BigQueryWriteClient,
AppendRowsRequest, AppendRowsResponse, GetWriteStreamRequest, ProtoSchema, WriteStream, WriteStreamView,
},
BIG_QUERY_V2_URL,
};
static BIG_QUERY_STORAGE_API_URL: &str = "https://bigquerystorage.googleapis.com";
static BIGQUERY_STORAGE_API_DOMAIN: &str = "bigquerystorage.googleapis.com";
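/// Protocol buffer scalar types that can be used to describe the columns of a
/// table when appending rows through the Storage Write API.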
#[derive(Clone, Copy)]
pub enum ColumnType {
Double,
Float,
Int64,
Uint64,
Int32,
Fixed64,
Fixed32,
Bool,
String,
Bytes,
Uint32,
Sfixed32,
Sfixed64,
Sint32,
Sint64,
}
impl From<ColumnType> for Type {
fn from(value: ColumnType) -> Self {
match value {
ColumnType::Double => Type::Double,
ColumnType::Float => Type::Float,
ColumnType::Int64 => Type::Int64,
ColumnType::Uint64 => Type::Uint64,
ColumnType::Int32 => Type::Int32,
ColumnType::Fixed64 => Type::Fixed64,
ColumnType::Fixed32 => Type::Fixed32,
ColumnType::Bool => Type::Bool,
ColumnType::String => Type::String,
ColumnType::Bytes => Type::Bytes,
ColumnType::Uint32 => Type::Uint32,
ColumnType::Sfixed32 => Type::Sfixed32,
ColumnType::Sfixed64 => Type::Sfixed64,
ColumnType::Sint32 => Type::Sint32,
ColumnType::Sint64 => Type::Sint64,
}
}
}
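/// Column mode, mirroring BigQuery's `NULLABLE`, `REQUIRED` and `REPEATED`
/// field modes (`Nullable` maps to the proto2 `optional` label).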
#[derive(Clone, Copy)]
pub enum ColumnMode {
Nullable,
Required,
Repeated,
}
impl From<ColumnMode> for Label {
fn from(value: ColumnMode) -> Self {
match value {
ColumnMode::Nullable => Label::Optional,
ColumnMode::Required => Label::Required,
ColumnMode::Repeated => Label::Repeated,
}
}
}
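/// Describes one column of the destination table.
///
/// `number` must match the protobuf field tag used when encoding the rows
/// passed to [`StorageApi::create_rows`], e.g. the `#[prost(..., tag = "N")]`
/// attribute on the row message.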
pub struct FieldDescriptor {
pub number: u32,
pub name: String,
pub typ: ColumnType,
pub mode: ColumnMode,
}
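/// Describes the schema of the destination table as a set of field
/// descriptors.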
pub struct TableDescriptor {
pub field_descriptors: Vec<FieldDescriptor>,
}
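/// Fully qualified name of a BigQuery write stream, formatted by its
/// [`Display`] impl as
/// `projects/{project}/datasets/{dataset}/tables/{table}/streams/{stream}`.
///
/// A minimal sketch (the project, dataset and table values are placeholders):
///
/// ```ignore
/// let name = StreamName::new_default("my-project".into(), "my_dataset".into(), "my_table".into());
/// assert_eq!(
///     name.to_string(),
///     "projects/my-project/datasets/my_dataset/tables/my_table/streams/_default"
/// );
/// ```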
pub struct StreamName {
project: String,
dataset: String,
table: String,
stream: String,
}
impl StreamName {
pub fn new(project: String, dataset: String, table: String, stream: String) -> StreamName {
StreamName {
project,
dataset,
table,
stream,
}
}
pub fn new_default(project: String, dataset: String, table: String) -> StreamName {
StreamName {
project,
dataset,
table,
stream: "_default".to_string(),
}
}
}
impl Display for StreamName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let StreamName {
project,
dataset,
table,
stream,
} = self;
f.write_fmt(format_args!(
"projects/{project}/datasets/{dataset}/tables/{table}/streams/{stream}"
))
}
}
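/// gRPC client for the BigQuery Storage Write API, typically reached through
/// the parent client's `storage_mut()` accessor (as in the test module below).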
#[derive(Clone)]
pub struct StorageApi {
write_client: BigQueryWriteClient<Channel>,
auth: Arc<dyn Authenticator>,
base_url: String,
}
impl StorageApi {
pub(crate) fn new(write_client: BigQueryWriteClient<Channel>, auth: Arc<dyn Authenticator>) -> Self {
Self {
write_client,
auth,
base_url: BIG_QUERY_V2_URL.to_string(),
}
}
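/// Opens a TLS channel (using the platform's native roots) to the Storage
/// Write API endpoint and wraps it in a `BigQueryWriteClient`.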
pub(crate) async fn new_write_client() -> Result<BigQueryWriteClient<Channel>, BQError> {
let tls_config = ClientTlsConfig::new()
.domain_name(BIGQUERY_STORAGE_API_DOMAIN)
.with_native_roots();
let channel = Channel::from_static(BIG_QUERY_STORAGE_API_URL)
.tls_config(tls_config)?
.connect()
.await?;
let write_client = BigQueryWriteClient::new(channel);
Ok(write_client)
}
pub(crate) fn with_base_url(&mut self, base_url: String) -> &mut Self {
self.base_url = base_url;
self
}
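/// Appends `rows` to the given write stream and returns the server's
/// response stream.
///
/// `rows` is typically built with [`StorageApi::create_rows`]; `trace_id` is
/// an arbitrary caller-supplied identifier attached to the request for
/// debugging. The request is sent with no offset and an unspecified
/// missing-value interpretation.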
pub async fn append_rows(
&mut self,
stream_name: &StreamName,
rows: append_rows_request::Rows,
trace_id: String,
) -> Result<Streaming<AppendRowsResponse>, BQError> {
let write_stream = stream_name.to_string();
let append_rows_request = AppendRowsRequest {
write_stream,
offset: None,
trace_id,
missing_value_interpretations: HashMap::new(),
default_missing_value_interpretation: MissingValueInterpretation::Unspecified.into(),
rows: Some(rows),
};
let req = self
.new_authorized_request(tokio_stream::iter(vec![append_rows_request]))
.await?;
let response = self.write_client.append_rows(req).await?;
let streaming = response.into_inner();
Ok(streaming)
}
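/// Encodes `rows` into an [`append_rows_request::Rows`] message, stopping
/// before the total size of the encoded rows exceeds `max_size_bytes` (the
/// Storage Write API caps AppendRows requests at 10 MB). Returns the encoded
/// rows together with the number of input rows consumed, so callers can slice
/// off the remainder and call again. A minimal sketch of that loop (`MyRow`,
/// `rows`, `table_descriptor`, `stream_name` and `client` are placeholders,
/// set up along the lines of the test module below):
///
/// ```ignore
/// let mut remaining: &[MyRow] = &rows;
/// while !remaining.is_empty() {
///     let (encoded, n) = StorageApi::create_rows(&table_descriptor, remaining, 9 * 1024 * 1024);
///     client.storage_mut().append_rows(&stream_name, encoded, "trace".to_string()).await?;
///     remaining = &remaining[n..];
/// }
/// ```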
pub fn create_rows<M: Message>(
table_descriptor: &TableDescriptor,
rows: &[M],
max_size_bytes: usize,
) -> (append_rows_request::Rows, usize) {
let field_descriptors = table_descriptor
.field_descriptors
.iter()
.map(|fd| {
let typ: Type = fd.typ.into();
let label: Label = fd.mode.into();
FieldDescriptorProto {
name: Some(fd.name.clone()),
number: Some(fd.number as i32),
label: Some(label.into()),
r#type: Some(typ.into()),
type_name: None,
extendee: None,
default_value: None,
oneof_index: None,
json_name: None,
options: None,
proto3_optional: None,
}
})
.collect();
let proto_descriptor = DescriptorProto {
name: Some("table_schema".to_string()),
field: field_descriptors,
extension: vec![],
nested_type: vec![],
enum_type: vec![],
extension_range: vec![],
oneof_decl: vec![],
options: None,
reserved_range: vec![],
reserved_name: vec![],
};
let proto_schema = ProtoSchema {
proto_descriptor: Some(proto_descriptor),
};
let mut serialized_rows = Vec::new();
let mut total_size = 0;
for row in rows {
let encoded_row = row.encode_to_vec();
let current_size = encoded_row.len();
if total_size + current_size > max_size_bytes {
break;
}
serialized_rows.push(encoded_row);
total_size += current_size;
}
let num_rows_processed = serialized_rows.len();
let proto_rows = crate::google::cloud::bigquery::storage::v1::ProtoRows { serialized_rows };
let proto_data = ProtoData {
writer_schema: Some(proto_schema),
rows: Some(proto_rows),
};
(append_rows_request::Rows::ProtoRows(proto_data), num_rows_processed)
}
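/// Wraps `t` in a gRPC request carrying a freshly issued `Bearer` token in
/// the `authorization` metadata.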
async fn new_authorized_request<D>(&self, t: D) -> Result<Request<D>, BQError> {
let access_token = self.auth.access_token().await?;
let bearer_token = format!("Bearer {access_token}");
let bearer_value = bearer_token.as_str().try_into()?;
let mut req = Request::new(t);
let meta = req.metadata_mut();
meta.insert("authorization", bearer_value);
Ok(req)
}
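/// Fetches metadata about a write stream; request `WriteStreamView::Full` to
/// also retrieve the stream's table schema.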
pub async fn get_write_stream(
&mut self,
stream_name: &StreamName,
view: WriteStreamView,
) -> Result<WriteStream, BQError> {
let get_write_stream_request = GetWriteStreamRequest {
name: stream_name.to_string(),
view: view.into(),
};
let req = self.new_authorized_request(get_write_stream_request).await?;
let response = self.write_client.get_write_stream(req).await?;
let write_stream = response.into_inner();
Ok(write_stream)
}
}
#[cfg(test)]
pub mod test {
use crate::model::dataset::Dataset;
use crate::model::field_type::FieldType;
use crate::model::table::Table;
use crate::model::table_field_schema::TableFieldSchema;
use crate::model::table_schema::TableSchema;
use crate::storage::{ColumnMode, ColumnType, FieldDescriptor, StorageApi, StreamName, TableDescriptor};
use crate::{env_vars, Client};
use prost::Message;
use std::time::{Duration, SystemTime};
use tokio_stream::StreamExt;
#[derive(Clone, PartialEq, Message)]
struct Actor {
#[prost(int32, tag = "1")]
actor_id: i32,
#[prost(string, tag = "2")]
first_name: String,
#[prost(string, tag = "3")]
last_name: String,
#[prost(string, tag = "4")]
last_update: String,
}
#[tokio::test]
async fn test_append_rows() -> Result<(), Box<dyn std::error::Error>> {
let (ref project_id, ref dataset_id, ref table_id, ref sa_key) = env_vars();
let dataset_id = &format!("{dataset_id}_storage");
let mut client = Client::from_service_account_key_file(sa_key).await?;
client.dataset().delete_if_exists(project_id, dataset_id, true).await;
let created_dataset = client.dataset().create(Dataset::new(project_id, dataset_id)).await?;
assert_eq!(created_dataset.id, Some(format!("{project_id}:{dataset_id}")));
let table = Table::new(
project_id,
dataset_id,
table_id,
TableSchema::new(vec![
TableFieldSchema::new("actor_id", FieldType::Int64),
TableFieldSchema::new("first_name", FieldType::String),
TableFieldSchema::new("last_name", FieldType::String),
TableFieldSchema::new("last_update", FieldType::Timestamp),
]),
);
let created_table = client
.table()
.create(
table
.description("A table used for unit tests")
.label("owner", "me")
.label("env", "prod")
.expiration_time(SystemTime::now() + Duration::from_secs(3600)),
)
.await?;
assert_eq!(created_table.table_reference.table_id, table_id.to_string());
let field_descriptors = vec![
FieldDescriptor {
name: "actor_id".to_string(),
number: 1,
typ: ColumnType::Int64,
mode: ColumnMode::Nullable,
},
FieldDescriptor {
name: "first_name".to_string(),
number: 2,
typ: ColumnType::String,
mode: ColumnMode::Nullable,
},
FieldDescriptor {
name: "last_name".to_string(),
number: 3,
typ: ColumnType::String,
mode: ColumnMode::Nullable,
},
FieldDescriptor {
name: "last_update".to_string(),
number: 4,
// The TIMESTAMP column receives its values as formatted strings here.
typ: ColumnType::String,
mode: ColumnMode::Nullable,
},
];
let table_descriptor = TableDescriptor { field_descriptors };
let actor1 = Actor {
actor_id: 1,
first_name: "John".to_string(),
last_name: "Doe".to_string(),
last_update: "2007-02-15 09:34:33 UTC".to_string(),
};
let actor2 = Actor {
actor_id: 2,
first_name: "Jane".to_string(),
last_name: "Doe".to_string(),
last_update: "2008-02-15 09:34:33 UTC".to_string(),
};
let stream_name = StreamName::new_default(project_id.clone(), dataset_id.clone(), table_id.clone());
let trace_id = "test_client".to_string();
let rows: &[Actor] = &[actor1, actor2];
// Max size large enough for both rows, so a single AppendRows call suffices.
let max_size = 9 * 1024 * 1024; // 9 MB
let num_append_rows_calls = call_append_rows(
&mut client,
&table_descriptor,
&stream_name,
trace_id.clone(),
rows,
max_size,
)
.await?;
assert_eq!(num_append_rows_calls, 1);
// Max size too small for both rows at once, forcing a second AppendRows call.
let max_size = 50;
let num_append_rows_calls =
call_append_rows(&mut client, &table_descriptor, &stream_name, trace_id, rows, max_size).await?;
assert_eq!(num_append_rows_calls, 2);
Ok(())
}
async fn call_append_rows(
client: &mut Client,
table_descriptor: &TableDescriptor,
stream_name: &StreamName,
trace_id: String,
mut rows: &[Actor],
max_size: usize,
) -> Result<u8, Box<dyn std::error::Error>> {
let mut num_append_rows_calls = 0;
loop {
let (encoded_rows, num_processed) = StorageApi::create_rows(table_descriptor, rows, max_size);
let mut streaming = client
.storage_mut()
.append_rows(stream_name, encoded_rows, trace_id.clone())
.await?;
num_append_rows_calls += 1;
while let Some(resp) = streaming.next().await {
let resp = resp?;
println!("response: {resp:#?}");
}
if num_processed == rows.len() {
break;
}
rows = &rows[num_processed..];
}
Ok(num_append_rows_calls)
}
}