import binascii
import hashlib
import math
import multiprocessing
import os
import random
import time
from dataclasses import dataclass
from datetime import timedelta
from queue import Empty, Full
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import backoff
import torch
from Crypto.Hash import keccak
from rich import console as rich_console
from rich import status as rich_status
import cybertensor
from cybertensor import __console__ as console
from cybertensor.utils._register_cuda import solve_cuda
from cybertensor.utils.formatting import get_human_readable, millify
from cybertensor.wallet import Wallet
class CUDAException(Exception):
    """Exception type for the CUDA solving path.

    NOTE(review): not raised anywhere in this module's visible code; presumably
    raised by callers or by the CUDA helper — confirm before relying on it.
    """
def _hex_bytes_to_u8_list(hex_bytes: bytes):
hex_chunks = [int(hex_bytes[i: i + 2], 16) for i in range(0, len(hex_bytes), 2)]
return hex_chunks
def _create_seal_hash(block_and_hotkey_hash_bytes: bytes, nonce: int) -> bytes:
    """Compute the proof-of-work seal for ``nonce``.

    The pre-image is the nonce as 8 little-endian bytes followed by the first
    32 bytes of ``block_and_hotkey_hash_bytes``. The seal is
    Keccak-256(SHA-256(pre-image)).

    The bytes are concatenated directly. An earlier form hex-encoded both
    parts and decoded them again, which gave the same pre-image at extra cost
    on every hash in the CPU solver's inner loop.

    Args:
        block_and_hotkey_hash_bytes: Combined block/hotkey hash; only the first
            32 bytes are used.
        nonce: Candidate nonce; must satisfy ``0 <= nonce < 2**64``.

    Returns:
        The 32-byte Keccak-256 digest.

    Raises:
        OverflowError: If ``nonce`` does not fit in 8 unsigned bytes.
    """
    pre_seal = nonce.to_bytes(8, "little") + bytes(block_and_hotkey_hash_bytes[:32])
    seal_sh256 = hashlib.sha256(pre_seal).digest()
    kec = keccak.new(digest_bits=256)
    return kec.update(seal_sh256).digest()
def _seal_meets_difficulty(seal: bytes, difficulty: int, limit: int):
seal_number = int.from_bytes(seal, "big")
product = seal_number * difficulty
return product < limit
@dataclass
class POWSolution:
    """A nonce (and its seal) that met the registration difficulty for a block."""

    nonce: int  # the winning nonce
    block_number: int  # block number the work was computed against
    difficulty: int  # difficulty the seal was checked against
    seal: bytes  # seal hash for ``nonce`` (see _create_seal_hash for the CPU path)

    def is_stale(self, cwtensor: "cybertensor.cwtensor") -> bool:
        """Return True once the chain head is more than 3 blocks past ``block_number``."""
        blocks_elapsed = cwtensor.get_current_block() - self.block_number
        return blocks_elapsed > 3
class _SolverBase(multiprocessing.Process):
proc_num: int
num_proc: int
update_interval: int
finished_queue: multiprocessing.Queue
solution_queue: multiprocessing.Queue
newBlockEvent: multiprocessing.Event
stopEvent: multiprocessing.Event
hotkey_bytes: bytes
curr_block: multiprocessing.Array
curr_block_num: multiprocessing.Value
curr_diff: multiprocessing.Array
check_block: multiprocessing.Lock
limit: int
def __init__(
self,
proc_num,
num_proc,
update_interval,
finished_queue,
solution_queue,
stopEvent,
curr_block,
curr_block_num,
curr_diff,
check_block,
limit,
):
multiprocessing.Process.__init__(self, daemon=True)
self.proc_num = proc_num
self.num_proc = num_proc
self.update_interval = update_interval
self.finished_queue = finished_queue
self.solution_queue = solution_queue
self.newBlockEvent = multiprocessing.Event()
self.newBlockEvent.clear()
self.curr_block = curr_block
self.curr_block_num = curr_block_num
self.curr_diff = curr_diff
self.check_block = check_block
self.stopEvent = stopEvent
self.limit = limit
def run(self):
raise NotImplementedError("_SolverBase is an abstract class")
@staticmethod
def create_shared_memory() -> (
Tuple[multiprocessing.Array, multiprocessing.Value, multiprocessing.Array]
):
curr_block = multiprocessing.Array("h", 32, lock=True) curr_block_num = multiprocessing.Value("i", 0, lock=True) curr_diff = multiprocessing.Array("Q", [0, 0], lock=True)
return curr_block, curr_block_num, curr_diff
class _Solver(_SolverBase):
    """CPU proof-of-work solver process.

    Each round hashes ``update_interval`` consecutive nonces starting at a fresh
    random offset. Any solution is pushed to ``solution_queue``, and
    ``proc_num`` is pushed to ``finished_queue`` after every round so the
    parent can estimate the hash rate. The parent sets ``newBlockEvent``
    before ``start()``. The first round depends on that snapshot of the shared
    block data.
    """

    def run(self):
        block_number: int
        block_and_hotkey_hash_bytes: bytes
        block_difficulty: int
        nonce_limit = int(math.pow(2, 64)) - 1
        # Largest start that keeps every nonce of [start, start + update_interval)
        # within 8 unsigned bytes. _create_seal_hash encodes nonces with
        # to_bytes(8, ...), which raises OverflowError for values >= 2**64.
        max_nonce_start = max(0, nonce_limit + 1 - self.update_interval)

        while not self.stopEvent.is_set():
            if self.newBlockEvent.is_set():
                # Read the shared block data under the same lock the parent
                # holds while publishing it, so the three values are consistent.
                with self.check_block:
                    block_number = self.curr_block_num.value
                    block_and_hotkey_hash_bytes = bytes(self.curr_block)
                    block_difficulty = _registration_diff_unpack(self.curr_diff)

                self.newBlockEvent.clear()

            # Every round starts at a new random nonce.
            nonce_start = random.randint(0, max_nonce_start)
            nonce_end = nonce_start + self.update_interval

            solution = _solve_for_nonce_block(
                nonce_start,
                nonce_end,
                block_and_hotkey_hash_bytes,
                block_difficulty,
                self.limit,
                block_number,
            )
            if solution is not None:
                self.solution_queue.put(solution)

            try:
                # One entry per finished round; the parent counts these.
                self.finished_queue.put_nowait(self.proc_num)
            except Full:
                pass
class _CUDASolver(_SolverBase):
    """Proof-of-work solver process that runs the nonce search on one CUDA device.

    After each call into ``_solve_for_nonce_block_cuda`` the start nonce advances
    by ``update_interval * TPB``, wrapping modulo ``2**64 - 1``.
    """

    dev_id: int
    TPB: int

    def __init__(
        self,
        proc_num,
        num_proc,
        update_interval,
        finished_queue,
        solution_queue,
        stopEvent,
        curr_block,
        curr_block_num,
        curr_diff,
        check_block,
        limit,
        dev_id: int,
        TPB: int,
    ):
        super().__init__(
            proc_num=proc_num,
            num_proc=num_proc,
            update_interval=update_interval,
            finished_queue=finished_queue,
            solution_queue=solution_queue,
            stopEvent=stopEvent,
            curr_block=curr_block,
            curr_block_num=curr_block_num,
            curr_diff=curr_diff,
            check_block=check_block,
            limit=limit,
        )
        self.dev_id = dev_id  # device index passed to solve_cuda
        self.TPB = TPB  # threads-per-block value passed to solve_cuda

    def run(self):
        nonce_limit = int(math.pow(2, 64)) - 1

        # Placeholder work, replaced as soon as newBlockEvent is observed.
        block_number: int = 0
        block_and_hotkey_hash_bytes: bytes = b"0" * 32
        block_difficulty: int = nonce_limit

        nonce_start = random.randint(0, nonce_limit)
        nonces_per_call = self.update_interval * self.TPB

        while not self.stopEvent.is_set():
            if self.newBlockEvent.is_set():
                with self.check_block:
                    block_number = self.curr_block_num.value
                    block_and_hotkey_hash_bytes = bytes(self.curr_block)
                    block_difficulty = _registration_diff_unpack(self.curr_diff)

                self.newBlockEvent.clear()

            found = _solve_for_nonce_block_cuda(
                nonce_start,
                self.update_interval,
                block_and_hotkey_hash_bytes,
                block_difficulty,
                self.limit,
                block_number,
                self.dev_id,
                self.TPB,
            )
            if found is not None:
                self.solution_queue.put(found)

            try:
                # Report a finished batch for the parent's hash-rate statistics.
                self.finished_queue.put(self.proc_num)
            except Full:
                pass

            nonce_start = (nonce_start + nonces_per_call) % nonce_limit
def _solve_for_nonce_block_cuda(
    nonce_start: int,
    update_interval: int,
    block_and_hotkey_hash_bytes: bytes,
    difficulty: int,
    limit: int,
    block_number: int,
    dev_id: int,
    TPB: int,
) -> Optional[POWSolution]:
    """Run one CUDA search batch starting at ``nonce_start`` on device ``dev_id``.

    ``solve_cuda`` reports "no solution" with a nonce of ``-1``; that maps to
    ``None`` here. Otherwise the nonce and seal are wrapped in a ``POWSolution``.
    """
    nonce, seal = solve_cuda(
        nonce_start,
        update_interval,
        TPB,
        block_and_hotkey_hash_bytes,
        difficulty,
        limit,
        dev_id,
    )
    if nonce == -1:
        return None
    return POWSolution(nonce, block_number, difficulty, seal)
def _solve_for_nonce_block(
    nonce_start: int,
    nonce_end: int,
    block_and_hotkey_hash_bytes: bytes,
    difficulty: int,
    limit: int,
    block_number: int,
) -> Optional[POWSolution]:
    """Search nonces in ``[nonce_start, nonce_end)`` on the CPU.

    Returns a ``POWSolution`` for the first (lowest) nonce whose seal meets
    ``difficulty`` under ``limit``, or ``None`` if none in the range does.
    """
    candidates = (
        (nonce, _create_seal_hash(block_and_hotkey_hash_bytes, nonce))
        for nonce in range(nonce_start, nonce_end)
    )
    for nonce, seal in candidates:
        if _seal_meets_difficulty(seal, difficulty, limit):
            return POWSolution(nonce, block_number, difficulty, seal)
    return None
def _registration_diff_unpack(packed_diff: multiprocessing.Array) -> int:
return int(packed_diff[0] << 32 | packed_diff[1])
def _registration_diff_pack(diff: int, packed_diff: multiprocessing.Array):
packed_diff[0] = diff >> 32
packed_diff[1] = diff & 0xFFFFFFFF
def _hash_block_with_hotkey(block_bytes: bytes, hotkey_bytes: bytes) -> bytes:
    """Return Keccak-256(block_bytes + Keccak-256(hotkey_bytes)).

    The result is the per-block, per-key value the solvers combine with each
    nonce.
    """
    hotkey_digest = keccak.new(digest_bits=256).update(bytearray(hotkey_bytes)).digest()
    combined = bytearray(block_bytes + hotkey_digest)
    return keccak.new(digest_bits=256).update(combined).digest()
def _update_curr_block(
    curr_diff: multiprocessing.Array,
    curr_block: multiprocessing.Array,
    curr_block_num: multiprocessing.Value,
    block_number: int,
    block_bytes: bytes,
    diff: int,
    hotkey_bytes: bytes,
    lock: multiprocessing.Lock,
):
    """Publish a new block's work description into the solvers' shared memory.

    While holding ``lock``, this stores ``block_number``, then the 32 bytes of
    ``_hash_block_with_hotkey(block_bytes, hotkey_bytes)`` into ``curr_block``,
    then the packed ``diff`` into ``curr_diff``. Solvers read all three under
    the same lock.
    """
    with lock:
        curr_block_num.value = block_number
        digest = _hash_block_with_hotkey(block_bytes, hotkey_bytes)
        for slot in range(32):
            curr_block[slot] = digest[slot]
        _registration_diff_pack(diff, curr_diff)
def get_cpu_count() -> int:
    """Return the number of CPUs this process may use, never less than 1.

    Prefers the scheduler affinity mask, which honours taskset/cgroup pinning.
    It falls back to ``os.cpu_count()`` where ``sched_getaffinity`` does not
    exist.
    """
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:
        # os.cpu_count() returns None when the count cannot be determined;
        # callers size process pools from this value, so keep it a positive int.
        return os.cpu_count() or 1
@dataclass
class RegistrationStatistics:
    """Mutable snapshot of solver progress shown by RegistrationStatisticsLogger."""

    time_spent_total: float  # seconds since solving started
    rounds_total: int  # worker rounds reported so far
    time_average: float  # running average of the per-pass time, in seconds
    time_spent: float  # seconds covered by the latest stats update
    hash_rate_perpetual: float  # hashes/s averaged over the whole run
    hash_rate: float  # hashes/s, weighted average of recent samples
    difficulty: int  # current registration difficulty
    block_number: int  # block the solvers are working on
    # "0x"-prefixed hex string as returned by the chain client: the solvers
    # call bytes.fromhex(block_hash[2:]) on it, so it is str, not bytes.
    block_hash: str
class RegistrationStatisticsLogger:
    """Render RegistrationStatistics on a rich console.

    With ``output_in_place`` a single rich status (spinner) is updated in place;
    otherwise every update is emitted with ``console.log``.
    """

    console: rich_console.Console
    status: Optional[rich_status.Status]

    def __init__(
        self, console: rich_console.Console, output_in_place: bool = True
    ) -> None:
        self.console = console
        # A live status display exists only for in-place output.
        if output_in_place:
            self.status = self.console.status("Solving")
        else:
            self.status = None

    def start(self) -> None:
        """Start the live status display (no-op without in-place output)."""
        if self.status is not None:
            self.status.start()

    def stop(self) -> None:
        """Stop the live status display (no-op without in-place output)."""
        if self.status is not None:
            self.status.stop()

    @classmethod
    def get_status_message(
        cls, stats: RegistrationStatistics, verbose: bool = False
    ) -> str:
        """Format ``stats`` as rich-markup text.

        Args:
            stats: Current solver statistics.
            verbose: Also include this round's time and the average round time.

        Returns:
            A multi-line string ending in a newline.
        """
        message = (
            "Solving\n"
            + f"Time Spent (total): [bold white]{timedelta(seconds=stats.time_spent_total)}[/bold white]\n"
            + (
                f"Time Spent This Round: {timedelta(seconds=stats.time_spent)}\n"
                + f"Time Spent Average: {timedelta(seconds=stats.time_average)}\n"
                if verbose
                else ""
            )
            + f"Registration Difficulty: [bold white]{millify(stats.difficulty)}[/bold white]\n"
            + f"Iters (Inst/Perp): [bold white]{get_human_readable(stats.hash_rate, 'H')}/s / "
            + f"{get_human_readable(stats.hash_rate_perpetual, 'H')}/s[/bold white]\n"
            + f"Block Number: [bold white]{stats.block_number}[/bold white]\n"
            # block_hash is already a printable "0x..." string; encoding it
            # would display the bytes repr (b'0x...') instead.
            + f"Block Hash: [bold white]{stats.block_hash}[/bold white]\n"
        )
        return message

    def update(self, stats: RegistrationStatistics, verbose: bool = False) -> None:
        """Show ``stats`` in the live status, or log them when not in place."""
        if self.status is not None:
            self.status.update(self.get_status_message(stats, verbose=verbose))
        else:
            self.console.log(self.get_status_message(stats, verbose=verbose))
def _solve_for_difficulty_fast(
    cwtensor: "cybertensor.cwtensor",
    wallet: "Wallet",
    netuid: int,
    output_in_place: bool = True,
    num_processes: Optional[int] = None,
    update_interval: Optional[int] = None,
    n_samples: int = 10,
    alpha_: float = 0.80,
    log_verbose: bool = False,
) -> Optional[POWSolution]:
    """Solve the registration proof-of-work with a pool of CPU worker processes.

    Args:
        cwtensor: Chain client used for block data, difficulty and the
            "already registered" check.
        wallet: Source of the key address hashed into the seal. The coldkey is
            used when ``netuid == -1``; the hotkey otherwise.
        netuid: Subnet to register on (``-1`` selects the coldkey path).
        output_in_place: Update one status display instead of logging each update.
        num_processes: Number of workers; defaults to the usable CPU count.
        update_interval: Nonces each worker hashes per round (default 50_000).
        n_samples: Number of recent hash-rate samples in the moving average.
        alpha_: Decay factor of the moving average; the newest sample weighs 1.
        log_verbose: Include per-round timing in the status output.

    Returns:
        The first ``POWSolution`` any worker finds. Returns ``None`` if the loop
        ends because the hotkey is already registered on ``netuid``.
    """
    if num_processes is None:
        # Use every CPU this process may run on, but at least one.
        num_processes = max(1, get_cpu_count())

    if update_interval is None:
        update_interval = 50_000

    limit = int(math.pow(2, 256)) - 1

    curr_block, curr_block_num, curr_diff = _Solver.create_shared_memory()

    stopEvent = multiprocessing.Event()
    stopEvent.clear()

    solution_queue = multiprocessing.Queue()
    finished_queues = [multiprocessing.Queue() for _ in range(num_processes)]
    check_block = multiprocessing.Lock()

    hotkey_bytes = (
        wallet.coldkeypub.address.encode()
        if netuid == -1
        else wallet.hotkey.address.encode()
    )

    solvers = [
        _Solver(
            i,
            num_processes,
            update_interval,
            finished_queues[i],
            solution_queue,
            stopEvent,
            curr_block,
            curr_block_num,
            curr_diff,
            check_block,
            limit,
        )
        for i in range(num_processes)
    ]

    # Publish the initial work before starting any worker.
    block_number, difficulty, block_hash = _get_block_with_retry(
        cwtensor=cwtensor, netuid=netuid
    )
    block_bytes = bytes.fromhex(block_hash[2:])
    old_block_number = block_number
    _update_curr_block(
        curr_diff,
        curr_block,
        curr_block_num,
        block_number,
        block_bytes,
        difficulty,
        hotkey_bytes,
        check_block,
    )
    for worker in solvers:
        worker.newBlockEvent.set()

    logger = RegistrationStatisticsLogger(console, output_in_place)

    for worker in solvers:
        worker.start()

    time_last = time.time()
    curr_stats = RegistrationStatistics(
        time_spent_total=0.0,
        time_average=0.0,
        rounds_total=0,
        time_spent=0.0,
        hash_rate_perpetual=0.0,
        hash_rate=0.0,
        difficulty=difficulty,
        block_number=block_number,
        block_hash=block_hash,
    )
    start_time_perpetual = time.time()
    logger.start()

    solution = None
    hash_rates = [0] * n_samples
    weights = [alpha_**i for i in range(n_samples)]

    try:
        while netuid == -1 or not cwtensor.is_hotkey_registered(
            netuid=netuid, hotkey=wallet.hotkey.address
        ):
            try:
                solution = solution_queue.get(block=True, timeout=0.25)
                if solution is not None:
                    break
            except Empty:
                pass

            old_block_number = _check_for_newest_block_and_update(
                cwtensor=cwtensor,
                netuid=netuid,
                hotkey_bytes=hotkey_bytes,
                old_block_number=old_block_number,
                curr_diff=curr_diff,
                curr_block=curr_block,
                curr_block_num=curr_block_num,
                curr_stats=curr_stats,
                update_curr_block=_update_curr_block,
                check_block=check_block,
                solvers=solvers,
            )

            # Count every round reported since the last pass. Draining each
            # queue, rather than taking one item, stops it growing without
            # bound and keeps the hash-rate estimate from under-counting.
            num_time = 0
            for finished_queue in finished_queues:
                try:
                    finished_queue.get(timeout=0.1)
                except Empty:
                    continue
                num_time += 1
                while True:
                    try:
                        finished_queue.get_nowait()
                    except Empty:
                        break
                    num_time += 1

            time_now = time.time()
            time_since_last = time_now - time_last
            if num_time > 0 and time_since_last > 0.0:
                hash_rate_ = (num_time * update_interval) / time_since_last
                hash_rates.append(hash_rate_)
                hash_rates.pop(0)  # keep only the newest n_samples
                # Weighted moving average: the newest sample (end of the list)
                # gets weight 1, and older ones decay by alpha_ per step.
                curr_stats.hash_rate = sum(
                    rate * weight
                    for rate, weight in zip(reversed(hash_rates), weights)
                ) / sum(weights)

                time_last = time_now

                curr_stats.time_average = (
                    curr_stats.time_average * curr_stats.rounds_total
                    + curr_stats.time_spent
                ) / (curr_stats.rounds_total + num_time)
                curr_stats.rounds_total += num_time

            curr_stats.time_spent = time_since_last
            new_time_spent_total = time_now - start_time_perpetual
            curr_stats.hash_rate_perpetual = (
                curr_stats.rounds_total * update_interval
            ) / new_time_spent_total
            curr_stats.time_spent_total = new_time_spent_total

            logger.update(curr_stats, verbose=log_verbose)
    finally:
        # Always stop the display and reap the workers, even if the loop raised.
        stopEvent.set()
        logger.stop()
        _terminate_workers_and_wait_for_exit(solvers)

    return solution
@backoff.on_exception(backoff.constant, Exception, interval=1, max_tries=3)
def _get_block_with_retry(
    cwtensor: "cybertensor.cwtensor", netuid: int
) -> Tuple[int, int, str]:
    """Fetch the current block number, registration difficulty and block hash.

    Any exception, including the ones raised below, is retried by ``backoff``
    up to 3 attempts, 1 second apart.

    Args:
        cwtensor: Chain client.
        netuid: Subnet whose difficulty to read; ``-1`` uses a fixed
            difficulty of 1_000_000.

    Returns:
        ``(block_number, difficulty, block_hash)``. ``block_hash`` is the value
        from ``get_block_hash``. Callers treat it as a "0x"-prefixed hex str
        (``bytes.fromhex(block_hash[2:])``), which is why it is annotated
        ``str``.

    Raises:
        Exception: If the block hash is unavailable (after retries).
        ValueError: If the chain returns no difficulty (after retries).
    """
    block_number = cwtensor.get_current_block()
    difficulty = 1_000_000 if netuid == -1 else cwtensor.difficulty(netuid=netuid)
    block_hash = cwtensor.get_block_hash(block_number)
    if block_hash is None:
        raise Exception(
            "Network error. Could not connect to substrate to get block hash"
        )
    if difficulty is None:
        raise ValueError("Chain error. Difficulty is None")
    return block_number, difficulty, block_hash
class _UsingSpawnStartMethod:
def __init__(self, force: bool = False):
self._old_start_method = None
self._force = force
def __enter__(self):
self._old_start_method = multiprocessing.get_start_method(allow_none=True)
if self._old_start_method is None:
self._old_start_method = "spawn"
multiprocessing.set_start_method("spawn", force=self._force)
def __exit__(self, *args):
multiprocessing.set_start_method(self._old_start_method, force=True)
def _check_for_newest_block_and_update(
    cwtensor: "cybertensor.cwtensor",
    netuid: int,
    old_block_number: int,
    hotkey_bytes: bytes,
    curr_diff: multiprocessing.Array,
    curr_block: multiprocessing.Array,
    curr_block_num: multiprocessing.Value,
    update_curr_block: Callable,
    check_block: "multiprocessing.Lock",
    solvers: List[_Solver],
    curr_stats: RegistrationStatistics,
) -> int:
    """Refresh the solvers' work if the chain head moved past ``old_block_number``.

    On a new block it re-fetches block data via ``_get_block_with_retry`` and
    publishes it with ``update_curr_block``. It then sets every solver's
    ``newBlockEvent`` and copies the block number, hash and difficulty into
    ``curr_stats``.

    Returns:
        The head block number from the initial ``get_current_block`` call, or
        ``old_block_number`` unchanged when the head has not moved.
    """
    head = cwtensor.get_current_block()
    if head == old_block_number:
        return old_block_number

    block_number, difficulty, block_hash = _get_block_with_retry(
        cwtensor=cwtensor, netuid=netuid
    )
    update_curr_block(
        curr_diff,
        curr_block,
        curr_block_num,
        block_number,
        bytes.fromhex(block_hash[2:]),
        difficulty,
        hotkey_bytes,
        check_block,
    )
    for solver in solvers:
        solver.newBlockEvent.set()

    curr_stats.block_number = block_number
    curr_stats.block_hash = block_hash
    curr_stats.difficulty = difficulty
    return head
def _solve_for_difficulty_fast_cuda(
    cwtensor: "cybertensor.cwtensor",
    wallet: "Wallet",
    netuid: int,
    output_in_place: bool = True,
    update_interval: int = 50_000,
    TPB: int = 512,
    dev_id: Union[List[int], int] = 0,
    n_samples: int = 10,
    alpha_: float = 0.80,
    log_verbose: bool = False,
) -> Optional[POWSolution]:
    """Solve the registration proof-of-work with one solver process per CUDA device.

    Args:
        cwtensor: Chain client used for block data, difficulty and the
            "already registered" check.
        wallet: Source of the key address hashed into the seal. The coldkey is
            used when ``netuid == -1``; the hotkey otherwise.
        netuid: Subnet to register on (``-1`` selects the coldkey path).
        output_in_place: Update one status display instead of logging each update.
        update_interval: Per-thread batch size handed to the CUDA solver
            (default 50_000).
        TPB: Threads per block handed to the CUDA solver.
        dev_id: Device index or list of indices; one worker per entry.
        n_samples: Number of recent hash-rate samples in the moving average.
        alpha_: Decay factor of the moving average; the newest sample weighs 1.
        log_verbose: Include per-round timing in the status output.

    Returns:
        The first ``POWSolution`` found, or ``None`` if the loop ends because
        the hotkey is already registered on ``netuid``.

    Raises:
        Exception: If CUDA is not available.
    """
    if isinstance(dev_id, int):
        dev_id = [dev_id]
    elif dev_id is None:
        dev_id = [0]

    if update_interval is None:
        update_interval = 50_000

    if not torch.cuda.is_available():
        raise Exception("CUDA not available")

    limit = int(math.pow(2, 256)) - 1

    with _UsingSpawnStartMethod(force=True):
        curr_block, curr_block_num, curr_diff = _CUDASolver.create_shared_memory()

        num_processes = len(dev_id)
        stopEvent = multiprocessing.Event()
        stopEvent.clear()
        solution_queue = multiprocessing.Queue()
        finished_queues = [multiprocessing.Queue() for _ in range(num_processes)]
        check_block = multiprocessing.Lock()

        # Same key bytes as the CPU solver (_solve_for_difficulty_fast). Both
        # paths feed them to _hash_block_with_hotkey, so a seal must not depend
        # on which solver found it.
        hotkey_bytes = (
            wallet.coldkeypub.address.encode()
            if netuid == -1
            else wallet.hotkey.address.encode()
        )

        solvers = [
            _CUDASolver(
                i,
                num_processes,
                update_interval,
                finished_queues[i],
                solution_queue,
                stopEvent,
                curr_block,
                curr_block_num,
                curr_diff,
                check_block,
                limit,
                dev_id[i],
                TPB,
            )
            for i in range(num_processes)
        ]

        # Publish the initial work before starting any worker.
        block_number, difficulty, block_hash = _get_block_with_retry(
            cwtensor=cwtensor, netuid=netuid
        )
        block_bytes = bytes.fromhex(block_hash[2:])
        old_block_number = block_number
        _update_curr_block(
            curr_diff,
            curr_block,
            curr_block_num,
            block_number,
            block_bytes,
            difficulty,
            hotkey_bytes,
            check_block,
        )
        for worker in solvers:
            worker.newBlockEvent.set()

        logger = RegistrationStatisticsLogger(console, output_in_place)

        for worker in solvers:
            worker.start()

        time_last = time.time()
        curr_stats = RegistrationStatistics(
            time_spent_total=0.0,
            time_average=0.0,
            rounds_total=0,
            time_spent=0.0,
            hash_rate_perpetual=0.0,
            hash_rate=0.0,
            difficulty=difficulty,
            block_number=block_number,
            block_hash=block_hash,
        )
        start_time_perpetual = time.time()
        logger.start()

        hash_rates = [0] * n_samples
        weights = [alpha_**i for i in range(n_samples)]
        solution = None
        # Nonces covered by one reported round of one worker.
        nonces_per_round = TPB * update_interval

        try:
            while netuid == -1 or not cwtensor.is_hotkey_registered(
                netuid=netuid, hotkey=wallet.hotkey.address
            ):
                try:
                    solution = solution_queue.get(block=True, timeout=0.15)
                    if solution is not None:
                        break
                except Empty:
                    pass

                old_block_number = _check_for_newest_block_and_update(
                    cwtensor=cwtensor,
                    netuid=netuid,
                    hotkey_bytes=hotkey_bytes,
                    curr_diff=curr_diff,
                    curr_block=curr_block,
                    curr_block_num=curr_block_num,
                    old_block_number=old_block_number,
                    curr_stats=curr_stats,
                    update_curr_block=_update_curr_block,
                    check_block=check_block,
                    solvers=solvers,
                )

                # Drain every report since the last pass. GPU workers can
                # report many rounds per pass, so taking one item per queue
                # would under-count and let the queue grow without bound.
                num_time = 0
                for finished_queue in finished_queues:
                    try:
                        finished_queue.get(timeout=0.1)
                    except Empty:
                        continue
                    num_time += 1
                    while True:
                        try:
                            finished_queue.get_nowait()
                        except Empty:
                            break
                        num_time += 1

                time_now = time.time()
                time_since_last = time_now - time_last
                if num_time > 0 and time_since_last > 0.0:
                    hash_rate_ = (num_time * nonces_per_round) / time_since_last
                    hash_rates.append(hash_rate_)
                    hash_rates.pop(0)  # keep only the newest n_samples
                    # Weighted moving average: the newest sample (end of the
                    # list) gets weight 1, and older ones decay by alpha_.
                    curr_stats.hash_rate = sum(
                        rate * weight
                        for rate, weight in zip(reversed(hash_rates), weights)
                    ) / sum(weights)

                    time_last = time_now

                    curr_stats.time_average = (
                        curr_stats.time_average * curr_stats.rounds_total
                        + curr_stats.time_spent
                    ) / (curr_stats.rounds_total + num_time)
                    curr_stats.rounds_total += num_time

                curr_stats.time_spent = time_since_last
                new_time_spent_total = time_now - start_time_perpetual
                curr_stats.hash_rate_perpetual = (
                    curr_stats.rounds_total * nonces_per_round
                ) / new_time_spent_total
                curr_stats.time_spent_total = new_time_spent_total

                logger.update(curr_stats, verbose=log_verbose)
        finally:
            # Always stop the display and reap the workers, even if the loop raised.
            stopEvent.set()
            logger.stop()
            _terminate_workers_and_wait_for_exit(solvers)

        return solution
def _terminate_workers_and_wait_for_exit(
workers: List[multiprocessing.Process],
) -> None:
for worker in workers:
worker.terminate()
worker.join()
def create_pow(
    cwtensor: "cybertensor.cwtensor",
    wallet: "Wallet",
    netuid: int,
    output_in_place: bool = True,
    cuda: bool = False,
    dev_id: Union[List[int], int] = 0,
    tpb: int = 256,
    num_processes: Optional[int] = None,
    update_interval: Optional[int] = None,
    log_verbose: bool = False,
) -> Optional["POWSolution"]:
    """Solve the registration proof-of-work for ``wallet`` on ``netuid``.

    Args:
        cwtensor: Chain client.
        wallet: Wallet whose key is bound into the seal.
        netuid: Target subnet; ``-1`` skips the subnet-existence check.
        output_in_place: Update one status display instead of logging each update.
        cuda: Use the CUDA solver instead of CPU worker processes.
        dev_id: CUDA device index or indices (CUDA only).
        tpb: Threads per block (CUDA only).
        num_processes: Number of CPU workers (CPU only; None = CPU count).
        update_interval: Batch size per worker round (None = solver default).
        log_verbose: Include per-round timing in the status output.

    Returns:
        The ``POWSolution`` found, or ``None`` if the hotkey became registered
        before a solution was found.

    Raises:
        ValueError: If ``netuid`` is not -1 and the subnet does not exist.
    """
    if netuid != -1 and not cwtensor.subnet_exists(netuid=netuid):
        raise ValueError(f"Subnet {netuid} does not exist")

    if cuda:
        return _solve_for_difficulty_fast_cuda(
            cwtensor,
            wallet,
            netuid=netuid,
            output_in_place=output_in_place,
            dev_id=dev_id,
            TPB=tpb,
            update_interval=update_interval,
            log_verbose=log_verbose,
        )

    return _solve_for_difficulty_fast(
        cwtensor,
        wallet,
        netuid=netuid,
        output_in_place=output_in_place,
        num_processes=num_processes,
        update_interval=update_interval,
        log_verbose=log_verbose,
    )