Skip to content

Commit

Permalink
Refactor the rls to remove the names and be outside the table schema
Browse files Browse the repository at this point in the history
  • Loading branch information
mamcx committed Oct 8, 2024
1 parent 526d7ad commit 44aec4f
Show file tree
Hide file tree
Showing 16 changed files with 149 additions and 243 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,16 @@ use super::{
state_view::{Iter, IterByColRange, ScanIterByColRange, StateView},
tx_state::{DeleteTable, IndexIdMap, RemovedIndexIdSet, TxState},
};
use crate::db::datastore::system_tables::{
StRowLevelSecurityRow, ST_ROW_LEVEL_SECURITY_ID, ST_ROW_LEVEL_SECURITY_IDX, ST_ROW_LEVEL_SECURITY_NAME,
};
use crate::{
db::{
datastore::{
system_tables::{
system_tables, StColumnRow, StConstraintData, StConstraintRow, StIndexAlgorithm, StIndexRow,
StSequenceRow, StTableFields, StTableRow, SystemTable, ST_CLIENT_ID, ST_CLIENT_IDX, ST_COLUMN_ID,
ST_COLUMN_IDX, ST_COLUMN_NAME, ST_CONSTRAINT_ID, ST_CONSTRAINT_IDX, ST_CONSTRAINT_NAME, ST_INDEX_ID,
ST_INDEX_IDX, ST_INDEX_NAME, ST_MODULE_ID, ST_MODULE_IDX, ST_RESERVED_SEQUENCE_RANGE, ST_SCHEDULED_ID,
ST_SCHEDULED_IDX, ST_SEQUENCE_ID, ST_SEQUENCE_IDX, ST_SEQUENCE_NAME, ST_TABLE_ID, ST_TABLE_IDX,
ST_VAR_ID, ST_VAR_IDX,
ST_INDEX_IDX, ST_INDEX_NAME, ST_MODULE_ID, ST_MODULE_IDX, ST_RESERVED_SEQUENCE_RANGE,
ST_ROW_LEVEL_SECURITY_ID, ST_ROW_LEVEL_SECURITY_IDX, ST_SCHEDULED_ID, ST_SCHEDULED_IDX, ST_SEQUENCE_ID,
ST_SEQUENCE_IDX, ST_SEQUENCE_NAME, ST_TABLE_ID, ST_TABLE_IDX, ST_VAR_ID, ST_VAR_IDX,
},
traits::TxData,
},
Expand Down Expand Up @@ -228,23 +225,7 @@ impl CommittedState {

self.create_table(ST_SCHEDULED_ID, schemas[ST_SCHEDULED_IDX].clone());

// Insert the rls into `st_row_level_security`
let (st_rls, blob_store) =
self.get_table_and_blob_store_or_create(ST_ROW_LEVEL_SECURITY_ID, &schemas[ST_ROW_LEVEL_SECURITY_IDX]);
for rls in ref_schemas.iter().flat_map(|x| &x.row_level_security).cloned() {
let row = StRowLevelSecurityRow {
table_id: rls.table_id,
row_level_security_id: rls.row_level_security_id,
row_level_security_name: rls.row_level_security_name,
sql: rls.sql,
};
let row = ProductValue::from(row);
// Insert the meta-row into the in-memory ST_ROW_LEVEL_SECURITY.
// If the row is already there, no-op.
ignore_duplicate_insert_error(st_rls.insert(blob_store, &row))?;
// Increment row count for st_row_level_security.
with_label_values(ST_ROW_LEVEL_SECURITY_ID, ST_ROW_LEVEL_SECURITY_NAME).inc();
}
self.create_table(ST_ROW_LEVEL_SECURITY_ID, schemas[ST_ROW_LEVEL_SECURITY_IDX].clone());

// IMPORTANT: It is crucial that the `st_sequences` table is created last

Expand Down
67 changes: 28 additions & 39 deletions crates/core/src/db/datastore/locking_tx_datastore/datastore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,13 +451,13 @@ impl MutTxDatastore for Locking {
tx.drop_row_level_security(ctx, row_level_security_policy_id)
}

fn row_level_security_id_from_name(
fn row_level_security_for_table_id(
&self,
tx: &Self::MutTx,
row_level_security_name: &str,
) -> crate::db::datastore::Result<Option<RowLevelSecurityId>> {
table_id: TableId,
) -> crate::db::datastore::Result<Vec<RowLevelSecuritySchema>> {
let ctx = &ExecutionContext::internal(self.database_address);
tx.row_level_security_id_from_name(ctx, row_level_security_name)
tx.row_level_security_for_table_id(ctx, table_id)
}

fn iter_mut_tx<'a>(
Expand Down Expand Up @@ -1061,15 +1061,6 @@ mod tests {
.sorted_by_key(|x| x.index_id)
.collect::<Vec<_>>())
}

pub fn scan_st_row_level_security(&self) -> Result<Vec<StRowLevelSecurityRow>> {
Ok(self
.db
.iter(self.ctx, ST_ROW_LEVEL_SECURITY_ID)?
.map(|row| StRowLevelSecurityRow::try_from(row).unwrap())
.sorted_by_key(|x| x.row_level_security_id)
.collect::<Vec<_>>())
}
}

fn u32_str_u32(a: u32, b: &str, c: u32) -> ProductValue {
Expand Down Expand Up @@ -1229,15 +1220,13 @@ mod tests {
struct RowLevelRow<'a> {
id: u32,
table: u32,
name: &'a str,
sql: &'a str,
}
impl From<RowLevelRow<'_>> for StRowLevelSecurityRow {
fn from(value: RowLevelRow<'_>) -> Self {
Self {
row_level_security_id: value.id.into(),
table_id: value.table.into(),
row_level_security_name: value.name.into(),
sql: value.sql.into(),
}
}
Expand Down Expand Up @@ -1290,7 +1279,6 @@ mod tests {
}),
},
],
vec![],
vec![SequenceSchema {
sequence_id: SequenceId::SENTINEL,
table_id: TableId::SENTINEL,
Expand Down Expand Up @@ -1326,7 +1314,6 @@ mod tests {
ConstraintRow { constraint_id: seq_start, table_id: table, unique_columns: col(0), constraint_name: "id_constraint" },
ConstraintRow { constraint_id: seq_start + 1, table_id: table, unique_columns: col(1), constraint_name: "name_constraint" }
]),
vec![],
map_array([
SequenceRow { id: seq_start, table, col_pos: 0, name: "id_sequence", start: 1 }
]),
Expand Down Expand Up @@ -1431,9 +1418,8 @@ mod tests {
ColRow { table: ST_SCHEDULED_ID.into(), pos: 3, name: "schedule_name", ty: AlgebraicType::String },

ColRow { table: ST_ROW_LEVEL_SECURITY_ID.into(), pos: 0, name: "row_level_security_id", ty: RowLevelSecurityId::get_type() },
ColRow { table: ST_ROW_LEVEL_SECURITY_ID.into(), pos: 1, name: "row_level_security_name", ty: AlgebraicType::String },
ColRow { table: ST_ROW_LEVEL_SECURITY_ID.into(), pos: 2, name: "table_id", ty: TableId::get_type() },
ColRow { table: ST_ROW_LEVEL_SECURITY_ID.into(), pos: 3, name: "sql", ty: AlgebraicType::String },
ColRow { table: ST_ROW_LEVEL_SECURITY_ID.into(), pos: 1, name: "table_id", ty: TableId::get_type() },
ColRow { table: ST_ROW_LEVEL_SECURITY_ID.into(), pos: 2, name: "sql", ty: AlgebraicType::String },
]));
#[rustfmt::skip]
assert_eq!(query.scan_st_indexes()?, map_array([
Expand All @@ -1447,8 +1433,9 @@ mod tests {
IndexRow { id: 8, table: ST_VAR_ID.into(), col: col(0), name: "idx_st_var_name_unique", },
IndexRow { id: 9, table: ST_SCHEDULED_ID.into(), col: col(0), name: "idx_st_scheduled_schedule_id_unique", },
IndexRow { id: 10, table: ST_SCHEDULED_ID.into(), col: col(1), name: "idx_st_scheduled_table_id_unique", },
IndexRow { id: 11, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(0), name: "idx_st_row_level_security_row_level_security_id_unique" },
IndexRow { id: 12, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(1), name: "idx_st_row_level_security_row_level_security_name_unique"},
IndexRow { id: 11, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(0), name: "idx_st_row_level_security_row_level_security_id_unique"},
IndexRow { id: 12, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(1), name: "idx_st_row_level_security_btree_table_id"},
IndexRow { id: 13, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(2), name: "idx_st_row_level_security_sql_unique"},
]));
let start = FIRST_NON_SYSTEM_ID as i128;
#[rustfmt::skip]
Expand Down Expand Up @@ -1479,7 +1466,7 @@ mod tests {
ConstraintRow { constraint_id: 9, table_id: ST_SCHEDULED_ID.into(), unique_columns: col(0), constraint_name: "ct_st_scheduled_schedule_id_unique" },
ConstraintRow { constraint_id: 10, table_id: ST_SCHEDULED_ID.into(), unique_columns: col(1), constraint_name: "ct_st_scheduled_table_id_unique" },
ConstraintRow { constraint_id: 11, table_id: ST_ROW_LEVEL_SECURITY_ID.into(), unique_columns: col(0), constraint_name: "ct_st_row_level_security_row_level_security_id_unique" },
ConstraintRow { constraint_id: 12, table_id: ST_ROW_LEVEL_SECURITY_ID.into(), unique_columns: col(1), constraint_name: "ct_st_row_level_security_row_level_security_name_unique" },
ConstraintRow { constraint_id: 12, table_id: ST_ROW_LEVEL_SECURITY_ID.into(), unique_columns: col(2), constraint_name: "ct_st_row_level_security_sql_unique" },
]));

// Verify we get back the tables correctly with the proper ids...
Expand Down Expand Up @@ -1892,7 +1879,8 @@ mod tests {
IndexRow { id: 9, table: ST_SCHEDULED_ID.into(), col: col(0), name: "idx_st_scheduled_schedule_id_unique", },
IndexRow { id: 10, table: ST_SCHEDULED_ID.into(), col: col(1), name: "idx_st_scheduled_table_id_unique", },
IndexRow { id: 11, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(0), name: "idx_st_row_level_security_row_level_security_id_unique" },
IndexRow { id: 12, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(1), name: "idx_st_row_level_security_row_level_security_name_unique"},
IndexRow { id: 12, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(1), name: "idx_st_row_level_security_btree_table_id"},
IndexRow { id: 13, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(2), name: "idx_st_row_level_security_sql_unique"},
IndexRow { id: seq_start, table: FIRST_NON_SYSTEM_ID, col: col(0), name: "id_idx", },
IndexRow { id: seq_start + 1, table: FIRST_NON_SYSTEM_ID, col: col(1), name: "name_idx", },
IndexRow { id: seq_start + 2, table: FIRST_NON_SYSTEM_ID, col: col(2), name: "age_idx", },
Expand Down Expand Up @@ -1947,7 +1935,8 @@ mod tests {
IndexRow { id: 9, table: ST_SCHEDULED_ID.into(), col: col(0), name: "idx_st_scheduled_schedule_id_unique", },
IndexRow { id: 10, table: ST_SCHEDULED_ID.into(), col: col(1), name: "idx_st_scheduled_table_id_unique", },
IndexRow { id: 11, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(0), name: "idx_st_row_level_security_row_level_security_id_unique" },
IndexRow { id: 12, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(1), name: "idx_st_row_level_security_row_level_security_name_unique"},
IndexRow { id: 12, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(1), name: "idx_st_row_level_security_btree_table_id"},
IndexRow { id: 13, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(2), name: "idx_st_row_level_security_sql_unique"},
IndexRow { id: seq_start , table: FIRST_NON_SYSTEM_ID, col: col(0), name: "id_idx" },
IndexRow { id: seq_start + 1, table: FIRST_NON_SYSTEM_ID, col: col(1), name: "name_idx" },
IndexRow { id: seq_start + 2, table: FIRST_NON_SYSTEM_ID, col: col(2), name: "age_idx" },
Expand Down Expand Up @@ -2003,7 +1992,8 @@ mod tests {
IndexRow { id: 9, table: ST_SCHEDULED_ID.into(), col: col(0), name: "idx_st_scheduled_schedule_id_unique", },
IndexRow { id: 10, table: ST_SCHEDULED_ID.into(), col: col(1), name: "idx_st_scheduled_table_id_unique", },
IndexRow { id: 11, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(0), name: "idx_st_row_level_security_row_level_security_id_unique" },
IndexRow { id: 12, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(1), name: "idx_st_row_level_security_row_level_security_name_unique"},
IndexRow { id: 12, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(1), name: "idx_st_row_level_security_btree_table_id"},
IndexRow { id: 13, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(2), name: "idx_st_row_level_security_sql_unique"},
IndexRow { id: seq_start, table: FIRST_NON_SYSTEM_ID, col: col(0), name: "id_idx" },
IndexRow { id: seq_start + 1, table: FIRST_NON_SYSTEM_ID, col: col(1), name: "name_idx" },
].map(Into::into));
Expand Down Expand Up @@ -2099,24 +2089,23 @@ mod tests {

let rls = RowLevelSecuritySchema {
row_level_security_id: RowLevelSecurityId::SENTINEL,
row_level_security_name: "rls_foo".into(),
sql: "SELECT * FROM bar".into(),
table_id,
};
let ctx = ExecutionContext::default();
let rls_id = datastore.create_row_level_security_mut_tx(&mut tx, table_id, rls.clone())?;

let id = datastore.create_row_level_security_mut_tx(&mut tx, table_id, rls)?;
let query = query_st_tables(&ctx, &tx);

#[rustfmt::skip]
assert_eq!(query.scan_st_row_level_security()?, [
RowLevelRow { id:id.into(), table: table_id.into(), name: "rls_foo", sql: "SELECT * FROM bar" },
].map(Into::into));
let result = datastore.row_level_security_for_table_id(&tx, table_id)?;
assert_eq!(
result,
vec![RowLevelSecuritySchema {
row_level_security_id: rls_id,
sql: "SELECT * FROM bar".into(),
table_id,
}]
);

let id = datastore.row_level_security_id_from_name(&tx, "rls_foo")?.unwrap();
datastore.drop_row_level_security_mut_tx(&mut tx, id)?;
let query = query_st_tables(&ctx, &tx);
assert_eq!(query.scan_st_row_level_security()?, []);
datastore.drop_row_level_security_mut_tx(&mut tx, rls_id)?;
assert_eq!(datastore.row_level_security_for_table_id(&tx, table_id)?, []);

Ok(())
}
Expand Down
57 changes: 22 additions & 35 deletions crates/core/src/db/datastore/locking_tx_datastore/mut_tx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@ impl MutTxId {
/// Requires:
/// - `row_level_security_schema.row_level_security_id == RowLevelSecurityId::SENTINEL`
/// - `row_level_security_schema.table_id != TableId::SENTINEL`
/// - `row_level_security_schema.row_level_security_name` must not be used for any other database entity.
/// - `row_level_security_schema.sql` must be unique.
///
/// Ensures:
///
Expand Down Expand Up @@ -982,19 +982,14 @@ impl MutTxId {
.into());
}

log::trace!(
"ROW LEVEL SECURITY CREATING: {} for table: {}",
row_level_security_schema.row_level_security_name,
table_id
);
log::trace!("ROW LEVEL SECURITY CREATING for table: {}", table_id);

// Insert the row into st_row_level_security
// NOTE: Because st_row_level_security has a unique index on security_name, this will
// fail if already exists.
let row = StRowLevelSecurityRow {
table_id,
row_level_security_id: RowLevelSecurityId::SENTINEL,
row_level_security_name: row_level_security_schema.row_level_security_name.clone(),
sql: row_level_security_schema.sql.clone(),
};

Expand All @@ -1003,12 +998,9 @@ impl MutTxId {
let existed = matches!(row.1, RowRefInsertion::Existed(_));

// Add the row level security to the transaction's insert table.
let (table, ..) = self.get_or_create_insert_table_mut(table_id)?;
self.get_or_create_insert_table_mut(table_id)?;
row_level_security_schema.row_level_security_id = row_level_security_id;

// This won't clone-write when creating a table but likely to otherwise.
table.with_mut_schema(|s| s.update_row_level_security(row_level_security_schema));

if existed {
log::trace!("ROW LEVEL SECURITY ALREADY EXISTS: {row_level_security_id}");
} else {
Expand All @@ -1018,6 +1010,25 @@ impl MutTxId {
Ok(row_level_security_id)
}

pub fn row_level_security_for_table_id(
&self,
ctx: &ExecutionContext,
table_id: TableId,
) -> Result<Vec<RowLevelSecuritySchema>> {
Ok(self
.iter_by_col_eq(
ctx,
ST_ROW_LEVEL_SECURITY_ID,
StRowLevelSecurityFields::TableId,
&table_id.into(),
)?
.map(|row| {
let row = StRowLevelSecurityRow::try_from(row).unwrap();
row.into()
})
.collect())
}

pub fn drop_row_level_security(
&mut self,
ctx: &ExecutionContext,
Expand All @@ -1035,35 +1046,11 @@ impl MutTxId {
.ok_or_else(|| {
TableError::IdNotFound(SystemTable::st_row_level_security, row_level_security_policy_id.into())
})?;
let table_id = st_rls_ref.read_col(StRowLevelSecurityFields::TableId)?;
self.delete(ST_ROW_LEVEL_SECURITY_ID, st_rls_ref.pointer())?;

// Remove rls in transaction's insert table.
let (table, ..) = self.get_or_create_insert_table_mut(table_id)?;
// This likely will do a clone-write as over time?
// The schema might have found other referents.
table.with_mut_schema(|s| s.remove_row_level_security(row_level_security_policy_id));

Ok(())
}

pub fn row_level_security_id_from_name(
&self,
ctx: &ExecutionContext,
row_level_security_name: &str,
) -> Result<Option<RowLevelSecurityId>> {
self.iter_by_col_eq(
ctx,
ST_ROW_LEVEL_SECURITY_ID,
StRowLevelSecurityFields::SecurityName,
&<Box<str>>::from(row_level_security_name).into(),
)
.map(|mut iter| {
iter.next()
.map(|row| row.read_col(StRowLevelSecurityFields::SecurityId).unwrap())
})
}

// TODO(perf, deep-integration):
// When all of [`Table::read_row`], [`RowRef::new`], [`CommittedState::get`]
// and [`TxState::get`] become unsafe,
Expand Down
15 changes: 0 additions & 15 deletions crates/core/src/db/datastore/locking_tx_datastore/state_view.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use super::{
committed_state::CommittedIndexIter, committed_state::CommittedState, datastore::Result, tx_state::TxState,
};
use crate::db::datastore::system_tables::{StRowLevelSecurityFields, StRowLevelSecurityRow, ST_ROW_LEVEL_SECURITY_ID};
use crate::{
db::datastore::system_tables::{
StColumnFields, StColumnRow, StConstraintFields, StConstraintRow, StIndexFields, StIndexRow, StScheduledFields,
Expand Down Expand Up @@ -124,26 +123,12 @@ pub trait StateView {
})
.transpose()?;

let row_level_security = self
.iter_by_col_eq(
ctx,
ST_ROW_LEVEL_SECURITY_ID,
StRowLevelSecurityFields::TableId,
value_eq,
)?
.map(|row| {
let row = StRowLevelSecurityRow::try_from(row)?;
Ok(row.into())
})
.collect::<Result<Vec<_>>>()?;

Ok(TableSchema::new(
table_id,
table_name,
columns,
indexes,
constraints,
row_level_security,
sequences,
table_type,
table_access,
Expand Down
Loading

0 comments on commit 44aec4f

Please sign in to comment.