@@ -156,22 +156,37 @@ void check_results(
156156 const std::vector<T>& C,
157157 const std::vector<T>& C_ref)
158158{
159- float err = 0 .f ;
159+ const float absolute = 1e-4f ;
160+
161+ float maxErr = 0 .f ;
162+ int errorCount = 0 ;
163+
160164 for (size_t m = 0 ; m < M; m++) {
161165 for (size_t n = 0 ; n < N; n++) {
162166 auto index = m * N + n;
163- auto localErr = std::fabs (C[index] - C_ref[index]) /
164- std::max (std::fabs (C[index]),
165- std::fabs (C_ref[index]));
166- err = std::max (localErr, err);
167- if (localErr >= threshold) {
168- std::cerr << " Error at m = " << m << " , n = " << n
169- << " : (local error " << localErr << " ): Wanted "
170- << C_ref[index] << " , got " << C[index] << std::endl;
171- return ;
167+ float got = static_cast <float >(C[index]);
168+ float want = static_cast <float >(C_ref[index]);
169+ float localErr = std::fabs (got - want);
170+ float localThreshold = absolute + threshold * std::fabs (want);
171+
172+ maxErr = std::max (localErr, maxErr);
173+ if (localErr > localThreshold) {
174+ if (errorCount < 1 ) {
175+ std::cerr << " Error at m = " << m << " , n = " << n
176+ << " : (abs error " << localErr << " , threshold "
177+ << localThreshold << " ): Wanted " << want
178+ << " , got " << got << std::endl;
179+ }
180+ ++errorCount;
172181 }
173182 }
174183 }
184+
185+ if (errorCount > 0 ) {
186+ std::cerr << " FAILED: " << errorCount << " of " << M * N
187+ << " elements exceeded tolerance. Max abs error: "
188+ << maxErr << std::endl;
189+ }
175190}
176191
177192static float hw_time (cl::Event& event)
0 commit comments