use cosmwasm_std::{ensure, Api, DepsMut, Env, MessageInfo, Order, StdResult, Storage};
use crate::math::{check_vec_max_limited, vec_u16_max_upscale_to_u16};
use crate::root::{contains_invalid_root_uids, get_root_netuid, if_subnet_exist};
use crate::state::{MIN_ALLOWED_WEIGHTS, WEIGHTS, WEIGHTS_VERSION_KEY};
use crate::uids::{
get_subnetwork_n, get_uid_for_net_and_hotkey, is_hotkey_registered_on_network,
is_uid_exist_on_network,
};
use crate::utils::{
get_last_update_for_uid, get_max_weight_limit, get_validator_permit_for_uid,
get_weights_set_rate_limit, set_last_update_for_uid,
};
use crate::ContractError;
use cyber_std::Response;
pub fn do_set_weights(
deps: DepsMut,
env: Env,
info: MessageInfo,
netuid: u16,
uids: Vec<u16>,
values: Vec<u16>,
version_key: u64,
) -> Result<Response, ContractError> {
let hotkey = info.sender;
ensure!(
uids_match_values(&uids, &values),
ContractError::WeightVecNotEqualSize {}
);
ensure!(
if_subnet_exist(deps.storage, netuid),
ContractError::NetworkDoesNotExist {}
);
if netuid == get_root_netuid() {
ensure!(
!contains_invalid_root_uids(deps.storage, deps.api, &uids),
ContractError::InvalidUid {}
);
} else {
ensure!(
check_len_uids_within_allowed(deps.storage, netuid, &uids),
ContractError::TooManyUids {}
);
}
ensure!(
is_hotkey_registered_on_network(deps.storage, netuid, &hotkey),
ContractError::NotRegistered {}
);
ensure!(
check_version_key(deps.storage, deps.api, netuid, version_key),
ContractError::IncorrectNetworkVersionKey {}
);
let neuron_uid;
let net_neuron_uid = get_uid_for_net_and_hotkey(deps.storage, netuid, &hotkey);
ensure!(net_neuron_uid.is_ok(), ContractError::NotRegistered {});
neuron_uid = net_neuron_uid.unwrap();
let current_block: u64 = env.block.height;
ensure!(
check_rate_limit(deps.storage, netuid, neuron_uid, current_block),
ContractError::SettingWeightsTooFast {}
);
if netuid != get_root_netuid() {
ensure!(
check_validator_permit(deps.storage, netuid, neuron_uid, &uids, &values),
ContractError::NoValidatorPermit {}
);
}
ensure!(!has_duplicate_uids(&uids), ContractError::DuplicateUids {});
if netuid != get_root_netuid() {
ensure!(
!contains_invalid_uids(deps.storage, deps.api, netuid, &uids),
ContractError::InvalidUid {}
);
}
ensure!(
check_length(deps.storage, netuid, neuron_uid, &uids, &values),
ContractError::NotSettingEnoughWeights {}
);
let max_upscaled_weights: Vec<u16> = vec_u16_max_upscale_to_u16(&values);
ensure!(
max_weight_limited(
deps.storage,
netuid,
neuron_uid,
&uids,
&max_upscaled_weights
),
ContractError::MaxWeightExceeded {}
);
let mut zipped_weights: Vec<(u16, u16)> = vec![];
for (uid, val) in uids.iter().zip(max_upscaled_weights.iter()) {
zipped_weights.push((*uid, *val))
}
WEIGHTS.save(deps.storage, (netuid, neuron_uid), &zipped_weights)?;
set_last_update_for_uid(deps.storage, netuid, neuron_uid, current_block);
Ok(Response::default()
.add_attribute("active", "weights_set")
.add_attribute("netuid", format!("{}", netuid))
.add_attribute("neuron_uid", format!("{}", neuron_uid)))
}
/// Returns whether `version_key` satisfies the subnet's stored weights version key.
///
/// A stored key of `0` accepts any `version_key`. Otherwise the supplied key must be greater
/// than or equal to the stored one. Panics if no version key is stored for `netuid`.
pub fn check_version_key(
    store: &dyn Storage,
    _api: &dyn Api,
    netuid: u16,
    version_key: u64,
) -> bool {
    match WEIGHTS_VERSION_KEY.load(store, netuid).unwrap() {
        0 => true,
        required => version_key >= required,
    }
}
/// Returns whether `neuron_uid` on `netuid` may set weights at `current_block`.
///
/// - Returns `false` if the uid does not exist on the network.
/// - Returns `true` if its stored last-update block is `0`.
/// - Otherwise returns `true` only when at least `get_weights_set_rate_limit` blocks have
///   passed since the last update.
///
/// If the stored block is ahead of `current_block`, the elapsed count saturates to `0`.
/// The previous code panicked on the `u64` underflow in that case.
pub fn check_rate_limit(
    store: &dyn Storage,
    netuid: u16,
    neuron_uid: u16,
    current_block: u64,
) -> bool {
    if !is_uid_exist_on_network(store, netuid, neuron_uid) {
        return false;
    }
    let last_set_weights: u64 = get_last_update_for_uid(store, netuid, neuron_uid);
    if last_set_weights == 0 {
        return true;
    }
    let rate_limit = get_weights_set_rate_limit(store, netuid);
    current_block.saturating_sub(last_set_weights) >= rate_limit
}
/// Returns `true` if any uid in `uids` does not exist on `netuid`.
///
/// Stops at the first missing uid. It emits an `api.debug` message naming that uid; the
/// previous code printed the entire `uids` vector in the `uid:` slot.
pub fn contains_invalid_uids(
    store: &dyn Storage,
    api: &dyn Api,
    netuid: u16,
    uids: &Vec<u16>,
) -> bool {
    let missing = uids
        .iter()
        .copied()
        .find(|&uid| !is_uid_exist_on_network(store, netuid, uid));
    match missing {
        Some(uid) => {
            api.debug(&format!(
                "๐ก contains_invalid_uids ( netuid:{:?}, uid:{:?} does not exist on network. )",
                netuid, uid
            ));
            true
        }
        None => false,
    }
}
/// Returns whether every uid has exactly one matching weight value (equal lengths).
pub fn uids_match_values(uids: &Vec<u16>, values: &Vec<u16>) -> bool {
    let (uid_count, value_count) = (uids.len(), values.len());
    uid_count == value_count
}
/// Returns `true` if any uid occurs more than once in `items`.
///
/// Uses an ordered set, so this runs in O(n log n) rather than the O(n²) of repeated
/// `Vec::contains`. `all` stops at the first uid already in the set (`insert` returns
/// `false` for it).
pub fn has_duplicate_uids(items: &Vec<u16>) -> bool {
    let mut seen = std::collections::BTreeSet::new();
    !items.iter().all(|item| seen.insert(*item))
}
/// Returns whether neuron `uid` may set these weights on `netuid`.
///
/// A single self-weight (see `is_self_weight`) is always allowed. Otherwise the result is the
/// neuron's validator permit. The permit is not read when the self-weight case already
/// passed.
pub fn check_validator_permit(
    store: &dyn Storage,
    netuid: u16,
    uid: u16,
    uids: &Vec<u16>,
    weights: &Vec<u16>,
) -> bool {
    is_self_weight(uid, uids, weights) || get_validator_permit_for_uid(store, netuid, uid)
}
pub fn check_length(
store: &dyn Storage,
netuid: u16,
uid: u16,
uids: &Vec<u16>,
weights: &Vec<u16>,
) -> bool {
let subnet_n: usize = get_subnetwork_n(store, netuid.clone()) as usize;
let min_allowed_length: usize = MIN_ALLOWED_WEIGHTS.load(store, netuid).unwrap() as usize;
let min_allowed: usize = {
if subnet_n.clone() < min_allowed_length.clone() {
subnet_n
} else {
min_allowed_length
}
};
if netuid != 0 && is_self_weight(uid, uids, weights) {
return true;
}
if weights.len() >= min_allowed {
return true;
}
return false;
}
/// Test helper that rescales `weights` so they sum to approximately `u16::MAX`.
///
/// Each weight becomes `w * u16::MAX / sum`, using `u64` arithmetic with truncating
/// division. An input whose sum is zero (including an empty input) is returned unchanged.
#[cfg(test)]
pub fn normalize_weights(weights: Vec<u16>) -> Vec<u16> {
    let total: u64 = weights.iter().copied().map(u64::from).sum();
    if total == 0 {
        return weights;
    }
    let scale = u64::from(u16::MAX);
    weights
        .into_iter()
        .map(|w| (u64::from(w) * scale / total) as u16)
        .collect()
}
/// Returns whether `weights` respects the subnet's max-weight limit.
///
/// A single self-weight always passes, as does a limit of `u16::MAX`. Otherwise the result
/// comes from `check_vec_max_limited(weights, limit)`.
pub fn max_weight_limited(
    store: &dyn Storage,
    netuid: u16,
    uid: u16,
    uids: &Vec<u16>,
    weights: &Vec<u16>,
) -> bool {
    if is_self_weight(uid, uids, weights) {
        return true;
    }
    match get_max_weight_limit(store, netuid) {
        u16::MAX => true,
        limit => check_vec_max_limited(weights, limit),
    }
}
/// Returns `true` when exactly one weight is given and the first uid is `uid` itself.
///
/// Uses `first()` rather than `uids[0]`, so an empty `uids` with one weight returns `false`.
/// The old indexing panicked in that case.
pub fn is_self_weight(uid: u16, uids: &Vec<u16>, weights: &Vec<u16>) -> bool {
    weights.len() == 1 && uids.first() == Some(&uid)
}
/// Returns whether `uids` has no more entries than the subnet currently has neurons.
pub fn check_len_uids_within_allowed(store: &dyn Storage, netuid: u16, uids: &Vec<u16>) -> bool {
    let max_len = usize::from(get_subnetwork_n(store, netuid));
    max_len >= uids.len()
}
/// Builds the dense `n x n` weight matrix for `netuid`, where `n` is the subnet size.
///
/// Entry `[i][j]` is the weight neuron `i` assigned to uid `j`, or `0` if none is stored.
///
/// # Errors
/// Returns any storage error from iterating `WEIGHTS`. The previous code panicked via
/// `unwrap()`.
///
/// # Panics
/// Panics if a stored row or column uid is `>= n`.
pub fn get_network_weights(store: &dyn Storage, netuid: u16) -> StdResult<Vec<Vec<u16>>> {
    let n = get_subnetwork_n(store, netuid) as usize;
    let mut weights: Vec<Vec<u16>> = vec![vec![0; n]; n];
    for item in WEIGHTS
        .prefix(netuid)
        .range(store, None, None, Order::Ascending)
    {
        let (uid_i, weights_i) = item?;
        let row = &mut weights[uid_i as usize];
        for &(uid_j, weight_ij) in weights_i.iter() {
            row[uid_j as usize] = weight_ij;
        }
    }
    Ok(weights)
}
/// Builds the sparse weight rows for `netuid`.
///
/// Returns one row per neuron (`n` = subnet size). Row `i` holds neuron `i`'s stored
/// `(uid, weight)` pairs in stored order. Rows with nothing stored are empty.
///
/// # Errors
/// Returns any storage error from iterating `WEIGHTS`. The previous code panicked via
/// `unwrap()`.
///
/// # Panics
/// Panics if a stored row uid is `>= n`.
pub fn get_network_weights_sparse(
    store: &dyn Storage,
    netuid: u16,
) -> StdResult<Vec<Vec<(u16, u16)>>> {
    let n = get_subnetwork_n(store, netuid) as usize;
    let mut weights: Vec<Vec<(u16, u16)>> = vec![vec![]; n];
    for item in WEIGHTS
        .prefix(netuid)
        .range(store, None, None, Order::Ascending)
    {
        let (uid_i, weights_i) = item?;
        weights[uid_i as usize].extend(weights_i);
    }
    Ok(weights)
}