728 x 90

Jak (poprawnie) zsumować milion floatów

Jak (poprawnie) zsumować milion floatów

Problem wydaje się banalny i każdy kto chociaż trochę programował powinien sobie z tym poradzić w kilku linijkach… często nie zdając sobie sprawy z tego, że wchodzi na pole minowe.


Mając na przykład tablicę float t[1000000], implementacja algorytmu sumowania może wyglądać tak:

float sum = 0;
for (int i = 0; i < 1000000; ++i) {
  sum += t[i];
}

Na pierwszy rzut oka widać, że algorytm jest poprawny, ale okazuje się, że dla miliona liczb zmiennoprzecinkowych implementacja jest fatalna, a wynik będzie błędny. Każdy programista wie, że obliczenia na liczbach zmiennoprzecinkowych obarczone są pewnym błędem. Jednak wiele osób nie zdaje sobie sprawy z tego jak duży błąd wyprodukuje powyższy program. Nawet przy milionie rozsądnych liczb z zakresu 0.1 do 1.0, błąd pojawi się nie tylko w części ułamkowej, ale będziemy go liczyć w setkach lub tysiącach! Żeby się o tym przekonać zerkniemy na prosty przykład. Potem zobaczymy jak sobie z tym problemem poradzić. Niżej będę przytaczał fragmenty kodu w języku C, ale wszystkie wymienione problemy i rozwiązania w takim samym stopniu dotyczą programów w językach Java, C++, C# i innych.

By nie pisać w kółko 1000000, niech od teraz N = 1000000. Nie ma znaczenia, że N wynosi dokładnie milion, ważne jest to, że jest to spora liczba.
Wypełnimy tablicę float t[N] pewną stałą. Wybrałem liczbę float C = 0.53125f. Wybór jest nieprzypadkowy, bo jest to jedna z garstki liczb, które w typie float reprezentowane są bez żadnego błędu. Jednak prezentowany efekt będzie występował dla większości liczb.

Dzięki temu, że wszystkie liczby w tablicy są takie same, to możemy policzyć na palcach, że suma miliona liczb równych 0.53125 powinna wynieść 531250.

#include <stdio.h>

void main() {
  const int N = 1000000;
  const float C = 0.53125f;
  float t[N];

  for (int i = 0; i < N; ++i) {
    t[i] = C;
  }

  float sum = 0;
  for (int i = 0; i < N; ++i) {
    sum += t[i];
  }

  printf("sum: %f\n", sum);
  printf("C*N: %f\n", C*N);
  printf("Diff: %f\n", C*N - sum);
}

Program wypisuje:

sum: 530840.500000
C*N: 531250.000000
Diff: 409.500000

Czyli dostaliśmy 530840, a chcieliśmy 531250… błąd wynosi ponad 400. Wielkość błędu w dużej mierze zależy od tego jakie liczby sumujemy, ale w większości przypadków będzie wysoki. Np. podmieniając powyżej stałą C na 1.0625, błąd wynosi ponad 800.

Jest źle, ale czasami zdarza się, że błąd znika. Na przykład gdy ustalimy stałą C na 2, to błąd wynosi 0.
Jednak gdy zwiększymy N do 20000000 to błąd wyniesie aż 6445568. Powinniśmy dostać sumę równą 40000000, a dostajemy 33554432. Myśleliśmy, że dla 2 wszystko jest ok, ale okazało się, że jest jeszcze gorzej.

Widać, że mamy poważny problem i nie możemy tego tak zostawić.

Zerknijmy jeszcze raz na naszą implementację sumowania:

float sum = 0;
for (int i = 0; i < N; ++i) {
  sum += t[i];
}

Program jest tak prosty i niewinny, że trudno wypatrzyć jakikolwiek błąd. Jedynym podejrzanym może być wnętrze pętli, bo tylko tam jest wykonywana jakaś praca, czyli: sum += t[i]. Jednocześnie wiemy, że operacje na liczbach zmiennoprzecinkowych obarczone są drobnym błędem, dlaczego więc u nas błąd jest aż taki duży?

Operacja sumy dwóch liczb zmiennoprzecinkowych rzeczywiście produkuje stosunkowo mały błąd, ale kluczem jest tutaj słowo stosunkowo; konkretnie błąd jest mały, ale tylko w stosunku do większej z dwóch sumowanych liczb. Czyli gdy np. mamy sumę 1000000 + 0.001 i dostaniemy w wyniku 1000000, czyli mylimy się o 0.001, to w stosunku do 1000000 popełniamy bardzo mały błąd (0.0000001%). Jednak w stosunku do liczby 0.001 jest to ogromny błąd (100%). Taką sytuację mamy właśnie w naszym programie. Zmienna sum na początku wynosi 0, ale potem rośnie i staje się bardzo duża w stosunku do sumowanych liczb; wraz z postępem pętli błąd jest coraz większy. Problem jest szczególnie dotkliwy właśnie przy sumowaniu. Mnożenie i dzielenie na floatach zachowuje się dużo lepiej. Jednak tutaj posługujemy się sumą. Dobra wiadomość jest taka, że o ile liczby mają podobną bezwzględną wartość to błąd sumowania jest bardzo mały. Spróbujmy to wykorzystać.

Sumowanie rekurencyjne

W tym momencie narzuca się pomysł: najpierw zsumujemy liczby rozłącznymi parami. Wyniki sum powinny być podobnej wielkości. Potem zsumujemy parami wyniki tych sumowań, parami wyniki wyników, itd. Dzięki temu przy założeniu, że liczby mają podobną wielkość, to na każdym etapie sumowania będziemy mieć liczby o podobnej wielkości.


Całę procedurę można zwięźle zawrzeć w implementacji rekurenycjnej korzystając z metody dziel i zwyciężaj. Przetestujmy następujący program:

#include <stdio.h>
#include <stdlib.h>

float sumrec(float *t, int N) {
  switch (N) {
    case 0: return 0;
    case 1: return t[0];
    case 2: return t[0] + t[1];
    default: return sumrec(t, N/2) + sumrec(t + N/2, N - N/2);
  }
}

void main() {
  const int N = 1000000;
  const float C = 0.53125f;
  float t[N];
  for (int i = 0; i < N; ++i) {
    t[i] = C;
  }

  float sum = sumrec(t, N);

  printf("sum: %f\n", sum);
  printf("C*N: %f\n", C*N);
  printf("Diff: %f\n", C*N - sum);
}

Wszystkie wcześniej testowane przypadki zwracają błąd równy 0. Czyli sukces? W rzeczywistości nie zawsze błąd będzie tak mały, ale zawsze dużo mniejszy niż przy sumowaniu w pętli for.

A co ze złożonością? Dla prostszego rachunku przyjmijmy, że N=2n, najpierw sumujemy liczby na najniższym poziomie i sumujemy rozłączne pary, czyli mamy 2n-1sumowań. Na wyższym poziomie 2n-2, itd. W sumie wykonalibyśmy 2n-1 + 2n-2 + … + 1 = 2n – 1 sumowań, czyli dokładnie tyle co przy algorytmie w pętli for.

Mimo tej samej złożoności tracimy trochę na rzeczywistej wydajności; ciągłe wywoływanie funkcji i przerzucanie wartości z miejsca na miejsce zawsze trochę kosztuje, ale wcale nie będzie tak źle. Pobieżne testy wydajnościowe na moim laptopie, przy kompilacji gcc z optymalizacjami -O3 wykazały, że dla naszego N wersja rekurencyjna zajęła tylko około 1.6 razy więcej czasu niż pierwsza wersja. Konkretnie dostałem czasy 1.59ms dla wersji rekurencyjnej kontra 0.98ms dla zwykłej sumy w pętli for.

Co ciekawe program ten będzie odwiedzał elementy tablicy w takiej samej kolejności jak algorytm z pętlą for, dzięki czemu ma on taką samą szansę na wykorzystanie pamięci cache procesora. Ten konkretny program da się też zorganizować w taki sposób, żeby korzystał z instrukcji wektorowych SSE. Operacje zmiennoprzecinkowe SSE zwykle są obarczone jeszcze większym błędem niż te w FPU, dlatego nasz algorytm zachowujący lepszą dokładność będzie miał w tym przypadku duże znaczenie.

Proponuję potestować ten algorytm na różnych przykładach. Okazuje się, że zachowuje się bardzo dobrze. Można np. spróbować zsumować jednostajnie rosnący ciąg N liczb od A do B. Suma powinna wynieść (A+B)*N/2, dzięki czemu możemy obliczyć popełniany błąd. Zwykle w stosunku do wielkości wyniku będzie bardzo mały lub równy zero. Można też spróbować zsumować małe liczby i wstawić do tablicy w kilku miejscach większe liczby. Okaże się, że tylko w skrajnych przypadkach program zwróci duży błąd. Np. gdy suma już dwóch sąsiednich liczb z tablicy jest obarczona dużym błędem.

Algorytm Kahana

Okazuje się, że można sumować jeszcze lepiej: bez użycia rekurencji i dokładniej. Dokonamy tego za pomocą algorytmu sumowania wymyślonego przez Williama Kahana, głównego projektanta standardu IEEE_754-1985, który stanowił podstawę stworzenia obecnie stosowanej reprezentacji liczb zmiennoprzecinkowych.

Najpierw zerknijmy na implementację, a potem rozłożymy go na części pierwsze.

float sum_kahan(float *t, int N) {
  float sum = 0.0;
  float err = 0.0;
  for (int i = 0; i < N; ++i) {
    float y = t[i] - err;
    float temp = sum + y;
    err = (temp - sum) - y;
    sum = temp;
  }
  return sum;
}

Najpierw upewnijmy się, że algorytm zwraca to co powinien. Na początku sum = 0. Potem w każdym kroku oblicza:

y = t[i] - err;
temp = sum + y;
err = (temp - sum) - y;
sum = temp;

Podstawmy y w drugiej i trzeciej linijce, dostaniemy:

temp = sum + t[i] - err;
err = temp - sum - t[i] + err;
sum = temp;

Teraz podstawy temp, tam gdzie się da:

err = sum + t[i] - err - sum - t[i] + err = 0
sum = sum + t[i] - err

Okazało się, że err = 0, czyli

sum = sum + t[i]

Czyli rzeczywiście, patrząc na program symbolicznie dostaniemy sumę elementów w tablicy. Ale skoro err wyniosło 0, to po co te wszystkie operacje powyżej? Oczywiście nie możemy tak patrzeć na program, bo w rzeczywistości operujemy na liczbach typu float i każda operacja jest obarczona błędem. Zmienna err ma taką nazwę nie bez powodu. W trakcie wykonania będzie się w niej akumulował błąd jaki popełniamy w każdym kroku sumowania. Pozostałe operacje są po to żeby po każdym sumowaniu ten błąd wyłuskać.

Zerknijmy najpierw na to jak wygląda jeden krok pętli, gdy err = 0:
1) y = t[i] - err,
czyli y = t[i], bo założyliśmy, że err = 0;
2) temp = sum + y
czyli w temp mamy sumę powiększoną o t[i]. Wiemy już, że tutaj zwykle powstaje duży błąd;
3) err = (temp - sum) - y;
(temp - sum) powinno dać y, czyli aktualnie dodawany element t[i]. Jeżeli nie został tam popełniony błąd to tyle właśnie dostaniemy, ale jeżeli podczas sumowania powstał błąd to dostaniemy y + błąd. Od tego odejmujemy aktualną wartość y i zostaje sam błąd sumowania, który zapisujemy do err;
4) sum = temp
wstawiamy aktualny stan sumy. Wiemy, że suma różni się od prawdziwego wyniku o wartość err, ale w kolejnych krokach spróbujemy odjąć ten błąd od całej sumy.

Jeżeli to nie był ostatni przebieg pętli to w kolejnym kroku zaczynamy od tego, że y = t[i] - err. Czyli kolejna liczba, którą będziemy dodawać w korku sumowania temp = sum + y, będzie pomniejszona o aktualnie wykonywany błąd. W pierwszym kroku jest właśnie próba redukcji poprzednio popełnionych błędów. Jeżeli nie da się zredukować sumy o cały błąd to brakujący fragment zostanie uwzględniony w kroku 3), i próba redukcji tego błędu zostanie powtórzona w kolejnym kroku.

Jest to bardzo dobry algorytm. Jedyną jego wadą jest to, że jest trochę wolniejszy od pierwotnej wersji sumowania, co nie powinno być zaskoczeniem, ponieważ w każdym kroku wykonuje on dużo więcej obliczeń. Jasnym jest, że złożoność obliczeniowa jest taka sama, ale w praktyce czas liczymy w milisekundach: u mnie sumowanie kahana zajęło 4ms kontra 0.98ms dla wersji pierwotnej, czyli wersja Kahana jest u mnie około 4 razy wolniejsza niż wersja pierwotna.

Pomysł akumulowania popełnianego błędu można spotkać w wielu innych algorytmach, nie tylko operujących na liczbach typu float. Sztandarowym przykładem jest algorytm Bresenhama, służacy do rysowania odcinków o dowolnym nachyleniu przy użyciu jedynie liczb całkowitych. Osobiście, ostatnio zdarzyło mi się użyć podobnego zabiegu przy synchronizacji czasowej kilku Arduino. Musiałem tam szybko obliczać średnie kroczące korzystając z arytmetyki stałoprzecinkowej o małej precyzji. Technika kumulowania popełnionego błędu w osobnej zmiennej i okresowego kumulowania błędu pomogła mi małym kosztem znacznie zwiększyć dokładność obliczeń.

Bardzo często przy okazji omawiania algorytmu Kahana zwraca się uwagę na to, że niektóre kompilatory mogą spojrzeć na wnętrze pętli for, potraktować obliczenia symbolicznie i zredukować wszystko do jednego wyrażenia sum += t[i], o którym słyszeliśmy tyle złego. Dlatego trzeba się upewnić, że nasz kompilator nie robi tak agresywnych optymalizacji. Większość kompilatorów zdaje sobie sprawę z tego, że nawet kolejność wykonywania dodawania zmiennych typu float wpływa na wynik obliczeń. Dlatego zwykle nie natrafimy na ten problem o ile sami intencjonalnie nie włączymy takiej optymalizacji. Przy kompilatorach gcc i clang nie spotkałem się z tym problemem. W gcc nawet przy wysokim poziomie optymalizacji, w tym z przełącznikiem -ffast-math, powyższa implementacja nadal zachowywała się poprawnie. Mimo wszystko w takiej sytuacji, zawsze warto wykonywać testy i upewniać się, że program działa tak jak byśmy tego chcieli.

Tak czy inaczej…

Algorytm Kahana nie zamyka tematu, lecz tylko pokazuje którędy droga. Omówione zagadnienie dotyczy nie tylko zwykłego sumowania tablic, ale każdej sytuacji, w której wykonujemy redukcję wielu liczb. W szczególności przy obliczaniu różnego rodzaju statystyk, np. na potrzeby metody monte-carlo, gdzie liczba zbieranych elementów często idzie w grube miliardy. Podobne problemy z dodawaniem małej liczby do dużej mogą być bardzo poważne przy implementacji schematów numerycznych we wszelkich symulacjach procesów fizycznych. Trzeba wówczas szczególnie uważać by przez przypadek nie zafałszować wyników.

Wiele osób pewno zauważy, że w omawianych przykładach problem mogliśmy też ominąć po prostu przechodząc z obliczeniami na liczby o podwójnej precyzji, czyli typ double. I rzeczywiście w pewnych przypadkach taki manewr wystarczy. Jednak jest to zabieg krótkowzroczny. Błędy obliczeń zmiennoprzecinkowych tak samo dotyczą typu double co float, tylko nie manifestują się tak szybko. Jeżeli nie mamy pewności co do ilości oraz wielkości sumowanych liczb, to mając metodę na polepszenie dokładności nierozsądnym by było z niej nie skorzystać. Dodatkowo warto pamiętać, że wiele sprzętów, szczególnie urządzenia mobilne i mikrokontrolery, sprzętowo nie obsługuje obliczeń o podwójnej precyzji. Może się okazać, że akumulowanie sum w double nie tylko będzie wolniejsze, a będzie też zużywało istotnie więcej energii. Problem jest jeszcze większy gdy urządzenie obsługuje liczby zmiennoprzecinkowe tylko w trybie emulacji. Wówczas użycie podwójnej precyzji może okazać się zupełnie niepraktyczne.

Najlepszym sposobem na ustrzeżenie się przed problemami z floatami jest po prostu sukcesywne poszerzanie wiedzy o obliczeniach na liczbach zmiennoprzecinkowych. Tam gdzie tylko możemy powinniśmy też stosować gotowe i sprawdzone funkcje, o ile tylko ufamy, że ich autorzy zadbali o stabilność działań na floatach. Trzeba też zawsze kontrolować dla jakich liczb nasze implementacje będą działać wystarczająco dokładnie, a dla jakich będą popełniać zbyt duży błąd. Dodatkowym sposobem zabezpieczenia się przed takimi problemami jest też ciągłe wymyślanie i wykonywanie nowych testów, które mogą czasem wskazać na niespodziewane przypadki graniczne, dla których obliczenia przestają być stabilne.

Na koniec polecam obejrzeć wykład z CppCon 2015 pt. “Demystifying Floating Point“, w którym John Farrier podsumowuje niuanse pracy z floatami; oraz trochę lżejszy filmik z kanału Computerphile, w którym Tom Scott opowiada o tym, dlaczego liczby zmiennoprzecinkowe są takie jakie są. Warto też przejrzeć bardzo dobry artykuł na angielskiej Wikipedii o liczbach zmiennoprzecinkowych, a w szczególności sekcję o dokładności obliczeń.

Źródło obrazu tytułowego: Courtesy NASA/JPL-Caltech.
source: https://www.nasa.gov/feature/jpl/when-computers-were-human

12 comments

Leave a Comment

Your email address will not be published. Required fields are marked with *

Cancel reply

12 Comments

  • Krzysztof Rykaczewski
    26 czerwca 2017, 14:58

    Świetny artykuł!

    REPLY
  • dev
    5 lipca 2017, 00:27

    A kto używa floatów ? 😀
    Double i po problemie !

    REPLY
    • kaczus@dev
      5 lipca 2017, 08:45

      Float od Double różni się tylko precyzją, problem i tak wystąpi, tylko będzie minimalnie mniej uciążliwy.

      REPLY
  • Artur
    6 lipca 2017, 18:47

    Świetny post, gratulacje.
    Próbowałem odtworzyć powyższy przypadek w kodzie PHP oraz JavaScript ale w obu przypadkach zwykłe sumowanie zwracało poprawny wynik. Dlaczego tak się dzieje ? Czy te języki maja jakąś lepszą natywną obsługę float-ów ?

    REPLY
    • x@Artur
      5 sierpnia 2017, 11:22

      PHP i Javascript są to dynamicznie typowane języki. Tak więc gdy przyjmujesz liczbę 0.53125 to dynamiczne typowanie traktuje ją nie jako float a jako double dlatego błąd nie występuje w tym przypadku. Ale gdy zamiast liczby 0.53125 przyjmiesz jakąś inną z mniejsza precyzją np 0.1 to błąd już się pojawia.

      REPLY
  • Ric
    7 lipca 2017, 08:18

    1. posortować wszystkie liczby
    2. dodać do siebie dwie najmniejsze
    3. zrobić binary search i wynik wstawić w rosnącą listę pozostałych liczb
    4. repeat until death

    REPLY
  • Michał
    13 lipca 2017, 19:00

    BigDecimal lub specjalistyczne klasy typu Money rozwiązują problem.

    REPLY

Inne artykuły