Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions infini_train/src/nn/init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,20 @@

namespace infini_train::nn::init {
namespace {
static std::random_device rd;
static std::mt19937 gen(rd());
constexpr int kRandomSeed = 42;

// FIXME: RNG design is incomplete.
//
// Current implementation lacks:
// - unified Generator abstraction
// - global default generator and seed control
// - reproducible / clonable RNG state
//
// TODO:
// - introduce Generator interface and backend impl
// - add default generator management (per device)
// - refactor random ops to consume Generator
static std::mt19937 gen(kRandomSeed);
} // namespace

std::shared_ptr<Tensor> Normal(const std::shared_ptr<Tensor> &tensor, float mean, float std,
Expand All @@ -34,7 +46,7 @@ std::shared_ptr<Tensor> Normal(const std::shared_ptr<Tensor> &tensor, float mean
#ifdef USE_OMP
#pragma omp parallel
{
std::mt19937 local_gen(std::random_device{}() + omp_get_thread_num());
std::mt19937 local_gen(kRandomSeed + omp_get_thread_num());
std::normal_distribution<float> local_dis(mean, std);
#pragma omp for
for (int i = 0; i < buffer.size(); ++i) {
Expand Down Expand Up @@ -126,7 +138,7 @@ std::shared_ptr<Tensor> Uniform(const std::shared_ptr<Tensor> &tensor, float a,
#ifdef USE_OMP
#pragma omp parallel
{
std::mt19937 local_gen(std::random_device{}() + omp_get_thread_num());
std::mt19937 local_gen(kRandomSeed + omp_get_thread_num());
std::uniform_real_distribution<float> local_dis(a, b);
#pragma omp for
for (int i = 0; i < buffer.size(); ++i) {
Expand Down
Loading