JMK no matter what

PKU3740 Time Attack: gradual optimization

어제 자기 전에 IRC 에 가보니 사람들이 열심히 pku 문제 하나 를 잡고 숏코딩과 타임어택을 하고 있었다. 숏코딩은 whitespace 포함 가장 작은 소스코드로 문제를 푸는 것이고, 타임어택은 가장 짧은 시간 걸려서 문제를 푸는 것. 아기 코끼리 덤보마냥 귀를 팔락거리면서 말려서 결국 타임어택에 도전했다. 641ms 였던 프로그램을 16ms 로 낮추면서 가장 낮은 시간 중 하나 가 됨. pku 저지의 timing resolution 때문에 아마 이보다 낮은 시간은 puts("output file"); 아니면 불가능하지 않을까 생각됨.

비교적 가장 단순하고 무식한 코드에서 시작해서 bottleneck 을 없애고 알고리즘을 최적화하는 방식으로 진행했기 때문에, 나름대로 재미있을 것 같아서 각 단계별로 소스코드를 모아보았다.

문제

문제는 일단, H*W 크기의 binary 매트릭스가 주어질 때, row 들의 서브셋을 구했을 때 이 row 들에서 각 column 에 1이 한 개씩만 포함되어 있도록 하기다.

First Try

뭐든지 타임어택에 제일 중요한 건 일단 답이 맞는 솔루션을 내는 것이다. 그래서 일단 처음에 문제를 풀기로 하고, H <= 16, W <= 300 이기 때문에, 자연스럽게 탑코더에서 자주 쓰는 O(2^H * W) 솔루션이 떠올라서 이런 코드를 작성했다.

lang:cpp
#include<algorithm>
#include<cstring>
#include<cstdio>
using namespace std;
int h, w, d[16][300];

void backt(int here, int* u)
{
    if(here == h)
    {
        if(find(u, u+w, 0) == u+w)
            throw 1;
    }
    else
    {
        int mx = 0;
        for(int i = 0; i < w; ++i)
            mx = max(mx, u[i] += d[here][i]);
        if(mx == 1) backt(here+1, u);
        for(int i = 0; i < w; ++i)
            u[i] -= d[here][i];
        backt(here+1, u);
    }
}

int main()
{
    while(scanf("%d %d", &h, &w) == 2)
    {
        for(int i = 0; i < h; ++i)
            for(int j = 0; j < w; ++j)
                scanf("%d", &d[i][j]);
        try
        {
            int u[300];
            memset(u, 0, sizeof(u));
            backt(0, u);
            puts("It is impossible");
        }
        catch(int)
        {
            puts("Yes, I found it");
        }
    }
}

각 row 를 택할지 말지를 백트래킹으로 결정하고, 이미 선택한 row 와 1 이 겹치면 끝내는, 단순한 2^n 방식.

printf 가 느리니까 puts 쓰고.. 반환값 체크하기 귀찮아서 또 '답찾으면 throw' 패턴을 썼다. 이 시점에서 641ms. ㅡㅡ 다른 사람들은 200ms 대에서 최적화하고 있는 상황.

Second Try: Input optimization

이런 작은 프로그램에서는 I/O 가 병목현상이 되는 일이 흔히 있다. 특히 적절히 추상화된 I/O 들은 더 느리기 일쑤. (Scanner 나 cin/cout 의 속도는 악명이 높다...) scanf/printf 는 더 빠르다고 하지만, 스트링 단위 raw I/O 의 속도를 따라갈 수 있을 리 없다. 그래서 행렬 입력 부분을 gets() 로 바꾸고 매뉴얼하게 파싱했다. 1 이랑 0 밖에 없기 때문에 파싱은 매우 간단.

lang:cpp
    char buf[700];
    while(scanf("%d %d", &h, &w) == 2)
    {
        gets(buf);
        for(int i = 0; i < h; ++i)
        {
            gets(buf);
            for(int j = 0; j < w; ++j)
                d[i][j] = buf[j<<1] - '0';
        }

바뀐 부분만 표시. 이런 최적화를 하자 483ms 로 내려갔다.

Third Try: Conflict check precalc

자 제일 당연한 건 했고 이제 뭐하지? 하고 쭉 훑어보니 이 row 를 택할 수 있나 판단하는 부분이 눈에 들어온다.

lang:cpp
        int mx = 0;
        for(int i = 0; i < w; ++i)
            mx = max(mx, u[i] += d[here][i]);
        if(mx == 1) backt(here+1, u);
        for(int i = 0; i < w; ++i)
            u[i] -= d[here][i];

지금은 일단, 각 column 에다가 숫자를 전부 더해준 뒤, 이중의 최대값이 1 이면 재귀호출에 들어간다. (2 이상이면 들어갈 필요도 없다. 같은 column 에 1이 두개 이상 있으면 어차피 안되니까) 근데, 첫 column 에서 conflict 가 나도 끝까지 봐야 되나? 만약 2 이상이 등장하면 곧장 break 해버리면 안되나?

물론 안된다. 이후에 u[i] 에서 d[here][i] 를 빼줘야 하는데, 중간에 break 해버리면 나중에 복구하는 과정에서 아직 더하지도 않은 부분에서 d[here][i] 를 빼버리게 된다. .. 이걸 적절히 노가다로 구현할 수도 있지만 그러고 싶진 않고, 다시 생각해 보자.

두 개의 row 가 주어졌을 때, 이 row 들에 겹치는 column 이 있는지를 오프라인으로 확인한다고 하자. 그리고 conflict[i] 를 i번째 row 와 겹치는 column 이 있는 row 들의 비트마스크라고 하면, 지금까지 선택한 row 들의 비트마스크 selected 가 있을 때 selected & conf[here] == 0 이어야지만 이 row 를 선택할 수 있다. 애초에 for(i) 루프를 돌 필요가 없어지는 것.

이와 같은 변화를 구현했다. (실제로는, 지금까지 선택한 row 들에 의해 더이상 선택할 수 없게 되어버린 row 들의 비트마스크 - conf[here] 의 합집합 - 을 유지했다. 별다른 차이는 없음)

lang:cpp
#include<algorithm>
#include<cstring>
#include<cstdio>
using namespace std;
int h, w, d[16][300], has[300], conf[16];

void backt(int here, int* u, int cant)
{
    if(here == h)
    {
        if(find(u, u+w, 0) == u+w)
            throw 1;
    }
    else
    {
        if(!(cant & (1 << here)))
        {
            for(int i = 0; i < w; ++i) u[i] += d[here][i];
            backt(here+1, u, cant | conf[here]);
            for(int i = 0; i < w; ++i) u[i] -= d[here][i];
        }
        backt(here+1, u, cant);
    }
}

int main()
{
    char buf[700];
    while(scanf("%d %d", &h, &w) == 2)
    {
        gets(buf);
        memset(has, 0, sizeof(has));
        for(int i = 0; i < h; ++i)
        {
            gets(buf);
            for(int j = 0; j < w; ++j)
            {
                d[i][j] = buf[j<<1] - '0';
                if(d[i][j]) has[j] += (1 << i);
            }
        }
        memset(conf, 0, sizeof(conf));
        for(int i = 0; i < h; ++i)
        {
            for(int j = 0; j < w; ++j)
                if(d[i][j]) conf[i] |= has[j] - (1 << i);
        }
        try
        {
            int u[300];
            memset(u, 0, sizeof(u));
            backt(0, u, 0);
            puts("It is impossible");
        }
        catch(int)
        {
            puts("Yes, I found it");
        }
    }
}

소스코드가 많이 변해서 전문을 올린다. conf[] 계산하기 위한 부분이 재미있는데, 일일이 pairwise 로 계산할 필요가 없다. 우선 각 비트에 대해 해당 비트가 1 인 row 들의 비트마스크를 has[i] 에 모은다. 그리고, i번째 row 의 j번째 비트가 켜져 있다면, has[j] 에 있는 애들은 모두 i 와 겹친다. 스스로를 빼 주면 has[j] - (1<<i) 가 됨. conf[i] 는 이들의 합집합이다.

이렇게 최적화하니까 313ms 로 줄어들었다.

Fourth Try: Change Algo

그리고 나면 이제 최적화할 거리가 딱히 눈에 들어오지 않는다. 최후의 수단으로 u[] 배열을 int64 다섯 개로 바꿔야 하나.. 이딴 고민을 하게 된다. 코드레벨 최적화를 하는 사람들이 대개 제일 먼저 찾는 것이 i) 루프 언롤링 ii) 변수 할당 최적화 따위인데 이런 것은 컴파일러도 최적화를 잘해주기 때문에 그닥 효과가 없다. 대개의 경우 최적화에 가장 큰 효과가 있는 것은 항상 알고리즘 최적화다. 과연 이 알고리즘이 최적일까?

이제 외부에서 최적화할 부분은 비교적 적다고 보고, 아무래도 재귀호출이 지수 크기의 상태공간을 탐색하다 보니 bottleneck 이라고 보는 것이 맞다. 재귀호출을 최적화할 수 있는 방법으로 가장 대표적인 것은 i) 내부 연산을 precalc ii) 탐색 방향을 바꾸는 것이 있겠다. 내부 연산 미리 연산하기는 conf[] 계산하면서 미리 했고.. 탐색 방향을 바꾸는 것 (요건 다음에 기회가 되면 포스팅) 에 대해 생각해 보자.

지금은 각 row 에 대해 선택할지/말지를 결정하는 알고리즘이다. 이 경우 선택지의 최종 개수가 2^16 으로 bound 되기 때문에 일단 마음이 편하다. 하지만, 마지막 선택까지 다 하기 전에 모든 비트가 다 켜졌는지를 확인할 방도가 없다. 게다가 이걸 확인하려면 O(w) 를 돌아야 한다. (이 글 쓰다 보니, 각 row 마다 켜진 비트의 수를 미리 세 놓고, 이 수의 합을 유지하면 된단 생각이 들었는데.. 의외로 빨라지진 않는다). 게다가 이 탐색은 '무식하다'. 당연하게 두 row 가 겹쳐서 탐색을 중단하는 경우를 제외하면, 끝까지 가기 전에 아무 pruning 도 할 수 없다. 예를 들어, 지금 선택한 두 개의 row 와 겹치는 row 들을 모두 제외하면 어떤 column 에 켜진 비트가 하나도 안 남는다고 하면, 당연히 답이 없다. 그런데 이 알고리즘은 설마 하면서 모든 답을 뒤져본다. 이와 같은 일이 있는지를 별도의 루프를 써서 확인할 수 있지만, 재귀호출 안에 추가로 루프를 넣는 것은 자살행위.

그럼 방향을 바꿔서, row 에 대해 선택하는 게 아니라 각 column 에 대해서 어떤 row 를 써서 이 column 에 1 을 넣을까를 탐색하면 어떨까? 그러니까,

lang:cpp
void backtr(int bit, int* u, int cant) // bit is column number
{
  if bit == w throw 1; // search is finished
  if bit[u] > 0
     backtr(bit+1, u, cant); // no need to select row
  else
    for every row i
      if row i has the bit, and (1<<row)&cant equals zero,
        increment u
        backtr(bit+1, u, cant)
        decrement u
}

대략 이런 식으로.. 실제로 짜진 않아서 소스 코드는 없다. 대충 알아봐라.

그러면, w 번째 column 까지 가면 반드시 모든 비트가 켜졌다는 것을 알 수 있고, 지금까지 선택한 row 와 안 겹치는 row 중 현재 column 에 비트가 켜진 것이 없으면 그냥 곧장 그만둘 수도 있다. 와, 유식하다!

이 방법에서 탐색의 상한선은 얼마일까? 물론 제일 무식하게 O(16^300) 일수도 있다. --;; 물론 이 탐색 공간의 크기는 1.7e361 인데 이런 탐색은 요한 계시록에 등장하는 네 명의 기사가 등장하여 지구가 멸망하고 태양계가 멸망하고 전우주가 블랙홀에 빨려들때까지 계산되지 않는다. -_-;;

그런데, 실제로 row 는 16개밖에 되지 않기 때문에 선택은 최대 16번밖에 하지 않는다. 나머지 bit 에서는 i) 이 column 이 이미 켜져 있으므로 다음 column 로 진행 ii) 켤 수 있는 row 가 없으므로 더 탐색하지 않고 반환 이것밖에 하지 않는다. 따라서 실제 선택지의 수는 대략 O(1! + 2! + ... + 16!) 으로 줄어든다.

그런데, 더 자세히 생각해 보자. 16! 이라면, row 1 2 3 이 있을 때 이들을 각각 다른 순서로 선택할 수 있다는 말이다. 그런데 그럴 수 있을까? 1 2 3 이 서로 겹치지 않는다면, 0 번째 비트는 1 2 3 중 하나에서만 켜져 있다. 그러면, bit = 0 을 지날 때는 항상 그 row 만을 선택하게 되지 다른 row 를 선택할 수는 없다. 따라서 항상 같은 순서로밖에 선택할 수 없다. 그러면 16 row 를 전부 선택하는 방법은 한개밖에 없다. 이렇게 reduce 해보면, 결국 선택지의 수는 O(2^16) 이다. 원래 문제랑 똑같다는 이야기. 할렐루야!

그래서 구현해 보자. 실제로는 여기서 한 발짝 더 나가서 세 가지 최적화를 구현했다.

  1. 현재 선택한 row 들의 비트마스크를 유지한다. 그러면, has[bit] 와 row 의 교집합이 있는지만 보면 이 비트가 켜졌는지를 알 수 있으니 u[] 를 유지할 필요가 없어진다!!!
  2. 각 column 마다 켜진 비트가 하나밖에 없다면 그 row 를 무조건 선택하자.
  3. 어떤 column 에 켜진 비트가 없으면 무조건 실패.

이와 같은 최적화를 모두 구현하자 다음과 같은 코드가 되었다.

lang:cpp
#include<algorithm>
#include<cstring>
#include<cstdio>
using namespace std;
int h, w, d[16][300], has[300], conf[16];

void backt(int bit, int taken, int cant)
{
    while(bit < w && (has[bit] & taken)) ++bit;
    if(bit == w) throw 1;
    int cands = has[bit] & ~cant;
    if(cands == 0) return;
    for(int row = 0; row < h; ++row)
        if(cands & (1<<row))
        {
            if(conf[row] & taken) continue;
            backt(bit+1, taken | (1<<row), cant | conf[row]);
        }
}

int main()
{
    char buf[700];
    int u[300];
    while(scanf("%d %d", &h, &w) == 2)
    {
        gets(buf);
        memset(has, 0, sizeof(has));
        for(int i = 0; i < h; ++i)
        {
            gets(buf);
            for(int j = 0; j < w; ++j)
            {
                d[i][j] = buf[j<<1] - '0';
                if(d[i][j])
                    has[j] |= (1<<i);
            }
        }
        memset(conf, 0, sizeof(conf));
        for(int bit = 0; bit < w; ++bit)
        {
            for(int row = 0; row < h; ++row)
                if(d[row][bit])
                    conf[row] |= (has[bit] - (1<<row));
        }
        int taken = 0, cant = 0;
        for(int bit = 0; bit < w; ++bit)
            if((has[bit] & (has[bit]-1)) == 0)
            {
                taken |= has[bit];
                cant |= conf[__builtin_ctz(has[bit])];
            }
        try
        {
            if((taken & cant) == 0 && find(has, has+w, 0) == has+w)
            {
                backt(0, taken, cant);
            }
            puts("It is impossible");
        }
        catch(int)
        {
            puts("Yes, I found it");
        }
    }
}

서브밋해보니 무려 63ms. 야 신난다. u 를 관리할 필요가 없는 게 굉장히 컸던 거 같다. 역시 가장 빠른 최적화는 알고리즘 레벨 최적화에서 온다. 아직도 최적화할 꺼리는 많으니 계속 최적화해보자.

Fifth Try: Shift Parameter Optimization

이제는 진짜로 특정 루프가 실행 시간에 큰 비중을 차지하지 않는다. 따라서 특정 루프를 최적화하거나 변경함으로써 프로그램 전체의 성능을 크게 끌어올릴 수 없다. 그럼 특정 instruction 은 어떨까. 비트마스크를 쓰는 프로그램에서는 시프트 오퍼레이터를 당연히 많이 쓰기 마련이다. 물론 이건 무지하게 빠르지만, 혹시나 해서 이 값을 배열에 넣어서 캐싱을 해보았다. 어차피 16개밖에 안되니 캐시도 잘 맞지 않을까 하는 기대를 해봤다.

lang:cpp
const int sh[16] = { 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768 };

이러고 1<<a 따위의 코드를 다 sh[a] 로 바꿨다. 큰 기대는 안 하고 내봤는데 47ms. 의외로 오래 걸리나 보다. 앗싸 좋구나. 이쯤 되니 어디까지 갈 수 있나 궁금해서 계속 최적화를 하게 된다.

Sixth Try: eliminating d[]

u 배열 북키핑을 하지 않으니 이제 더이상 d 배열 (원래 주어진 행렬) 을 쓸 필요가 없다. has[] 배열만 있으면 충분하다. 어차피 16x300 바이트 배열따위 크기만 해서 액세스 효율도 떨어진다. 그래서, d 를 지워버리고 ㅋ입력부의 코드를 다음과 같이 바꾼다.

lang:cpp
        for(int i = 0; i < h; ++i)
        {
            gets(buf);
            for(int j = 0; j < w; ++j)
                if(buf[j*2] == '1')
                    has[j] |= sh[i];
        }

        memset(conf, 0, sizeof(conf));

        for(int bit = 0; bit < w; ++bit)
        {
            for(int row = 0; row < h; ++row)
                if(has[bit]&sh[row])
                    conf[row] |= (has[bit] - sh[row]);
        }

32ms 로 줄어든다.

Seventh Try: whatever..

이젠 프로그램에서 사용하는 메모리도 한줌밖에 안된다. 이거 다해봐야 2kb 도 안되고 더이상 최적화할 꺼리가 없다. 그래서 설마하면서 안하던걸 다해보기로 한다. 혹시 exception handling 에 부하가 있을 까 해서 함수 호출도 성공 여부를 반환하기로 바꿨다. 변수 선언도 다 위로 빼서 내부에서 선언하지 않기로 한다. 텅 빈 column 이 있는지 확인을 find() 알고리즘을 써서 하고 있는데 위의 루프에 합칠 수 있을 것 같아서 합쳤다. 이런 발악을 거치자 다음과 같은 코드가 된다.

lang:cpp
#include<algorithm>
#include<cstring>
#include<cstdio>
using namespace std;
int h, w, has[300], conf[16];
const int sh[16] = { 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768 };

bool backt(int bit, int taken, int cant)
{
    while(bit < w && (has[bit] & taken)) ++bit;
    if(bit == w) return true;
    int cands = has[bit] & ~cant;
    if(cands == 0) return false;
    for(int row = 0; row < h; ++row)
        if(cands & sh[row])
        {
            if(conf[row] & taken) continue;
            if(backt(bit+1, taken | sh[row], cant | conf[row])) return true;
        }
    return false;
}

int main()
{
    char buf[700];
    int i, j, bit, row, taken, cant;
    while(scanf("%d %d", &h, &w) == 2)
    {
        gets(buf);
        memset(has, 0, sizeof(has));
        for(i = 0; i < h; ++i)
        {
            gets(buf);
            for(j = 0; j < w; ++j)
                if(buf[j*2] == '1')
                    has[j] |= sh[i];
        }

        memset(conf, 0, sizeof(conf));
        for(bit = 0; bit < w; ++bit)
        {
            for(row = 0; row < h; ++row)
                if(has[bit]&sh[row])
                    conf[row] |= (has[bit] - sh[row]);
        }
        taken = 0, cant = 0;
        for(bit = 0; bit < w; ++bit)
        {
            if(has[bit] == 0)
            {
                taken = cant = 1;
                break;
            }
            if((has[bit] & (has[bit]-1)) == 0)
            {
                taken |= has[bit];
                cant |= conf[__builtin_ctz(has[bit])];
            }
        }
        if((taken & cant) == 0 && backt(0, taken, cant))
        {
            puts("Yes, I found it");
        }
        else
        {
            puts("It is impossible");
        }
    }
}

결과는 16ms. 할렐루야! 여기선 실제 최적화보다도 그냥 online judge 의 타이밍이 이렇게 미세하게까지 정확하지 않아서 오차 때문에 +- 있는 게 효과를 본 것 같기도 하다. 어쨌건 이래서 어제의 타임어택은 끝을 맺었다.

Analysis

이거이 속도 향상을 나타낸 그래프이다. 실제로 slope 를 보나 절대적인 향상을 보나, 알고리즘 최적화가 가장 큰 역할을 했다는 것을 알 수 있다. (precalc 도 알고리즘 최적화라고 보고)

결론: 프로그램을 빨리 돌리고 싶을 때는

  • 일단 O2 를 켜라
  • 이상한 루프 언롤링 직접 하지 말고 precalc 나 알고리즘을 바꿔봐라

가 되겠다.

아 길었다. 이제 일해야지 (응?...)

2009-11-12 23:37:40 | JM | /logs/ | 2 Comments
Being
2009-11-13 00:48:04
오오 경험치가 1 상승!
LIBe
2009-11-13 01:07:46
선리플 후감상이요!

저같은 경우엔 Third Try에 나오는 이야기에 복구 문제의 경우 저도 shortcoding을 하던 도중에 부딛혔던 문제였는데 쉽게 해결이 안ㅋ되ㅋ어서 좀 슬펐다능.

형의 멋진 글에 저의 소스코드를 선물로 복붙 하겠습니다-_-b
int M,N,a[4800];f(int n,int*c){int i,t[300];if(n>=M)return 0;for(i=0;i<N;t[i++]=c[i]);for(;--i>=0&2>(c[i]+=a[n*N+i]););if(i<0){for(;++i<N&c[i];);if(i==N|f(n+1,c))return 1;}for(i=0;i<N;c[i++]=t[i]);return f(n+1,c);}main(i){for(;~scanf("%d%d",&M,&N);){for(i=0;i<N*M;i++)scanf("%d",a+i);int c[300]={};puts(f(0,c)?"Yes, I found it":"It is impossible");}}

Leave a comment

春來不似春

About

Eventstream

Pages

Guestbook

Search

Site Admin

Recent Comments