Skip to content

Commit cdee85b

Browse files
committed
fix error calculation
1 parent 2ae7f9b commit cdee85b

2 files changed

Lines changed: 50 additions & 20 deletions

File tree

samples/20_matrixexperiments-bf16/main.cpp

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -156,22 +156,37 @@ void check_results(
156156
const std::vector<T>& C,
157157
const std::vector<T>& C_ref)
158158
{
159-
float err = 0.f;
159+
const float absolute = 1e-4f;
160+
161+
float maxErr = 0.f;
162+
int errorCount = 0;
163+
160164
for (size_t m = 0; m < M; m++) {
161165
for (size_t n = 0; n < N; n++) {
162166
auto index = m * N + n;
163-
auto localErr = std::fabs(C[index] - C_ref[index]) /
164-
std::max(std::fabs(C[index]),
165-
std::fabs(C_ref[index]));
166-
err = std::max(localErr, err);
167-
if (localErr >= threshold) {
168-
std::cerr << "Error at m = " << m << ", n = " << n
169-
<< ": (local error " << localErr << "): Wanted "
170-
<< C_ref[index] << ", got " << C[index] << std::endl;
171-
return;
167+
float got = static_cast<float>(C[index]);
168+
float want = static_cast<float>(C_ref[index]);
169+
float localErr = std::fabs(got - want);
170+
float localThreshold = absolute + threshold * std::fabs(want);
171+
172+
maxErr = std::max(localErr, maxErr);
173+
if (localErr > localThreshold) {
174+
if (errorCount < 1) {
175+
std::cerr << "Error at m = " << m << ", n = " << n
176+
<< ": (abs error " << localErr << ", threshold "
177+
<< localThreshold << "): Wanted " << want
178+
<< ", got " << got << std::endl;
179+
}
180+
++errorCount;
172181
}
173182
}
174183
}
184+
185+
if (errorCount > 0) {
186+
std::cerr << "FAILED: " << errorCount << " of " << M * N
187+
<< " elements exceeded tolerance. Max abs error: "
188+
<< maxErr << std::endl;
189+
}
175190
}
176191

177192
static float hw_time(cl::Event& event)

samples/20_matrixexperiments-tf32/main.cpp

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -154,22 +154,37 @@ void check_results(
154154
const std::vector<T>& C,
155155
const std::vector<T>& C_ref)
156156
{
157-
float err = 0.f;
157+
const float absolute = 1e-4f;
158+
159+
float maxErr = 0.f;
160+
int errorCount = 0;
161+
158162
for (size_t m = 0; m < M; m++) {
159163
for (size_t n = 0; n < N; n++) {
160164
auto index = m * N + n;
161-
auto localErr = std::fabs(C[index] - C_ref[index]) /
162-
std::max(std::fabs(C[index]),
163-
std::fabs(C_ref[index]));
164-
err = std::max(localErr, err);
165-
if (localErr >= threshold) {
166-
std::cerr << "Error at m = " << m << ", n = " << n
167-
<< ": (local error " << localErr << "): Wanted "
168-
<< C_ref[index] << ", got " << C[index] << std::endl;
169-
return;
165+
float got = static_cast<float>(C[index]);
166+
float want = static_cast<float>(C_ref[index]);
167+
float localErr = std::fabs(got - want);
168+
float localThreshold = absolute + threshold * std::fabs(want);
169+
170+
maxErr = std::max(localErr, maxErr);
171+
if (localErr > localThreshold) {
172+
if (errorCount < 1) {
173+
std::cerr << "Error at m = " << m << ", n = " << n
174+
<< ": (abs error " << localErr << ", threshold "
175+
<< localThreshold << "): Wanted " << want
176+
<< ", got " << got << std::endl;
177+
}
178+
++errorCount;
170179
}
171180
}
172181
}
182+
183+
if (errorCount > 0) {
184+
std::cerr << "FAILED: " << errorCount << " of " << M * N
185+
<< " elements exceeded tolerance. Max abs error: "
186+
<< maxErr << std::endl;
187+
}
173188
}
174189

175190
static float hw_time(cl::Event& event)

0 commit comments

Comments
 (0)