Supercon 2011 ソース

(僕が開催者だったら、ジャッジ終わるまでの一週間はあんまり晒して欲しくないけど・・・)

とりあえず、晒していいものか分からないけど晒す。配列確保しすぎてダーク♂。

方針としては、

ただのダイクストラでスタートとゴールからの距離とか求めて、それに基づいて配列のオフセットとかいろいろ決める。

k-最短路はダークな方法で求める。

最初組んだksp遅かったので、spaghetti sourceのサイドトラック云々の手法を真似てコスト変更加えてみた。そしたら数秒以上速くなってビビった。

06/23

アルゴリズムの説明.txt加えた。

/* SuperCon 2011 予選問題C用テンプレート
   ・解答プログラムはこのテンプレートに従って作成すること.
  ・解答プログラムは1つのファイルで,チーム名.c という名前にすること.
  ・入力の方式は,あらかじめ入力ファイル(例:input_sample.txt)を作っ
   ておき,実行時にファイル名を指定する方式です.
*/

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

/* ↓以下の範囲は変更可能 */

#define W_MAX (201)
#define H_MAX (201)
#define N_MAX (8041)
#define Q_MAX (3200010)
#define max(a,b) (a>b?a:b)

int *memo[W_MAX][H_MAX];
int cost[W_MAX][H_MAX][4];
int begin[W_MAX][H_MAX];
int shortest[W_MAX][H_MAX];
int globalLength;

int dx[] = {0,1,0,-1};
int dy[] = {1,0,-1,0};

int qsize;
typedef struct {
  int key;
  int realcost;
  int x,y;
} QUE;
QUE que[Q_MAX+1];

#define PARENT(i) ((i)>>1)
#define LEFT(i)   ((i)<<1)
#define RIGHT(i)  (((i)<<1)+1)

void init(void)
{
  qsize = 0;
}

static void min_heapify(int i)
{
  int l, r;
  int smallest;

  l = LEFT(i), r = RIGHT(i);
  if (l < qsize && que[l].key < que[i].key) smallest = l; else smallest = i;
  if (r < qsize && que[r].key < que[smallest].key) smallest = r;
  if (smallest != i) {
    QUE t = que[i]; que[i] = que[smallest]; que[smallest] = t;
    min_heapify(smallest);
  }
}

int deq(QUE *q)
{
  if (qsize == 0) return 0;
  memcpy(q, &que[0], sizeof(QUE));
  que[0] = que[--qsize];
  min_heapify(0);
  return 1;
}

void enq(QUE *q)
{
  int i, ii;

  i = qsize++;
  memcpy(&que[i], q, sizeof(QUE));
  while (i > 0 && que[ii = PARENT(i)].key > que[i].key) {
    QUE t = que[i]; que[i] = que[ii]; que[ii] = t;
    i = ii;
  }
}


int ksp(int sx, int sy, int gx, int gy, int k) {
  int dist[W_MAX][H_MAX]={{0}};
  int prev[W_MAX][H_MAX];
  int i,key,cst,x,y;
  QUE temp;
  
  init();
  
  temp.key = 0 , temp.realcost = 0 , temp.x = sx , temp.y = sy ;  
  enq(&temp);
  while ( deq(&temp) != 0 ) {
	key = temp.key;
	cst = temp.realcost;
	x = temp.x ;
	y = temp.y ;
		
    if (dist[x][y] >= k || (dist[x][y] >= 1 && prev[x][y] == cst)) continue;
	else{
		dist[x][y]++ , prev[x][y] = cst;
	}
	
	if(dist[x][y] == k && x == gx && y == gy) return cst;
	
	for(i=0;i<4;i++){
		if(cost[x][y][i] != 0){
			temp.x = x+dx[i] , temp.y = y+dy[i];
			temp.key = key + (cost[x][y][i] - shortest[x][y] + shortest[temp.x][temp.y]);
			temp.realcost = cst + cost[x][y][i];
			enq(&temp);
		}
	}
  }
  return 0;
}

int dijkstra(int sx, int sy,int gx,int gy,int save_to[W_MAX][H_MAX]) {
  char done[W_MAX][H_MAX]={{0}};
  int i,c,x,y;
  QUE temp;
  init();  
  temp.key = 0 , temp.x = sx , temp.y = sy;  
  enq(&temp);
  while ( deq(&temp) != 0 ) {
	c = temp.key ;
	x = temp.x ;
	y = temp.y ;
		
    if(done[x][y]) continue;
	else save_to[x][y] = c , done[x][y] = 1;
	
	for(i=0;i<4;i++){
		if(cost[x][y][i] != 0){
			temp.key = c + cost[x][y][i] , temp.x = x+dx[i] , temp.y = y+dy[i];
			enq(&temp);
		}
	}
  }
  return (int)done[gx][gy];
}

int dfs(int x,int y,int c){
	if(c  > globalLength - shortest[x][y]) return 0;
	if(~memo[x][y][ c - begin[x][y] ])return memo[x][y][ c - begin[x][y] ];
	int d, ans = 0;	
	for(d=0;d<4;d++){
		if(cost[x][y][d]){
			ans += cost[x][y][d] ? dfs( x+dx[d] , y+dy[d] , c + cost[x][y][d]) : 0;
		}
	}
	
	return memo[x][y][ c - begin[x][y] ] = ans;
	
}

int total(int m,int n,int answer_length){
	int i,j;
	for(i=0;i<=m;i++)for(j=0;j<=n;j++){
		memo[i][j] = (int *)malloc( sizeof(int) * ( max(0,answer_length - shortest[i][j] - begin[i][j]+1) + 1) );
		memset(memo[i][j] , -1 ,    sizeof(int) * ( max(0,answer_length - shortest[i][j] - begin[i][j]+1) + 1) );
	}
	
	globalLength = answer_length;
	memo[m][n][answer_length - begin[m][n] ] = 1;
	return dfs(0,0,0);
}

/* ↑上記の範囲は変更可能 */

int main(int argc, char** argv) {
  int answer_length = -1; /* この変数に k 番目に長い経路の長さを代入してください */
  int answer_number = -1; /* この変数に k 番目に長い経路の総数を代入してください */
  int m, n, k;
  int D[200+1][200+1][4];
  char* problem_file;
  clock_t start, end;
  FILE* fp;

  int i, x, y;
  char buf[0xffff];
  char* p;

  if (argc <= 1) {
    fprintf(stderr, "Enter the input file.\n");
    exit(EXIT_FAILURE);
  }

  problem_file = argv[1];
  fp = fopen(problem_file, "r");
  if (fp == NULL) {
    fprintf(stderr, "Cannot open %s.\n", problem_file);
    exit(EXIT_FAILURE);
  }

  p = fgets(buf, 0xffff, fp);
  assert(p != 0);

  m = atoi(strtok(buf, ", \n"));
  n = atoi(strtok(NULL, ", \n"));
  k = atoi(strtok(NULL, ", \n"));
  assert(m > 0 && m <= 200);
  assert(n > 0 && n <= 200);
  assert(k > 0 && k <= 200);
  for (y = 0; y <= n; y++) {
    p = fgets(buf, 0xffff, fp);
    assert(p != 0);
    p = strtok(buf, ", \n");
    for (i = 0; i < m; i++) {
      int length = atoi(p);
      assert(length >= 0 && length <= 20);
      D[i][y][1] = length;
      D[i+1][y][3] = length;
      p = strtok(NULL, ", \n");
    }
    D[0][y][3] = 0;
    D[m][y][1] = 0;
  }
  for (x = 0; x <= m; x++) {
    p = fgets(buf, 0xffff, fp);
    assert(p != 0);
    p = strtok(buf, ", \n");
    for (i = 0; i < n; i++) {
      int length = atoi(p);
      assert(length >= 0 && length <= 20);
      D[x][i][0] = length;
      D[x][i+1][2] = length;
      p = strtok(NULL, ", \n");
    }
    D[x][0][2] = 0;
    D[x][n][0] = 0;
  }

  fclose(fp);

  /* 入力した情報を画面に出力する(デバッグ等に使って下さい)
   提出時は削除しないで,このようにコメントアウトすること
  printf("The input graph\n");
  for (i = 2*n; i >= 0; i--) {
    y = i / 2;
    if (i % 2 == 0) {
      printf("+");
      for (x = 0; x < m; x++) {
        assert(D[x][y][1] == D[x+1][y][3]);
        printf("%d+", D[x][y][1]);
      }
      assert(D[0][y][3] == 0);
      assert(D[m][y][1] == 0);
    } else if (i > 0) {
      for (x = 0; x <= m; x++) {
        assert(D[x][y][0] == D[x][y+1][2]);
        assert(y != n-1 || D[x][y+1][0] == 0);
        printf("%d ", D[x][y][0]);
      }
    } else {
      for (x = 0; x <= m; x++) {
        assert(D[x][y][2] == 0);
      }
    }
    printf("\n");
  }
  printf("\n");
  */

  /* 時間計測用(デバッグ等に使って下さい)
   提出時は削除しないで,このようにコメントアウトすること
  start = clock();
  */

/* ↓以下の範囲は変更可能 */
	memcpy(cost,D,sizeof(cost));
	
	if( dijkstra(m,n,0,0,shortest) ){
		dijkstra(0,0,m,n,begin);
		answer_length = ksp(0,0,m,n,k);
		answer_number = total(m,n,answer_length);
	}else{
		answer_length = 0;
		answer_number = 0;
	}
/* ↑上記の範囲は変更可能 */


  /* 時間計測用(デバッグ等に使って下さい)
   提出時は削除しないで,このようにコメントアウトすること
  end = clock();
  printf("%s, %f, %d, %d\n", problem_file, (double)(end - start) / CLOCKS_PER_SEC, answer_length, answer_number);
  */

  printf("%s, %d, %d\n", problem_file, answer_length, answer_number);
  return 0;
}

アルゴリズムの説明

■概要

本問題を解くにあたっては、k番目に短い経路長を求めるのには、ダイクストラ法を拡張したものを用い、経路総数を求めるのには、動的計画法を用いた。

動的計画法を用いた理由は、ナイーブな探索と比べると非常に高速に解を求めることが出来ると判断したからである。

実際の実装にあたっては、ループを用いた手法は無駄が多かったので、メモ化再帰と呼ばれる手法を用いて最適化を図った。

■プログラムの全体の流れ

(1) ダイクストラ法を用い、(m,n)地点から各地点への最短経路を予め計算しておく。また有効な経路が存在するかも同時に確認する。

(2) 再びダイクストラ法を用いて、(0,0)地点から各地点への最短経路を予め計算しておく。

(3) k番目に短いパスを探すために、ksp()関数を実行する。

(4) (1) , (2)で得られた結果を使い、格子点毎に必要なメモリ空間を、メモ化用配列に動的にそれぞれ割り当てる。

(4) 再帰探索により経路の総数を求める。枝狩りを用いて探索空間を減らしつつ、メモ化も用いることで探索効率を向上させた。""

■各アルゴリズムの詳細

・k番目に短い経路長を求めるために用いるアルゴリズム(ダイクストラを拡張したもの)

優先度付きキューを用いたため計算量は、O( k(|E|+|V|) log |V|)である。

本来、最短経路を求めるダイクストラ法では、一度確定した頂点に後からたどり着くような経路は全て切り捨てるが、

このアルゴリズムでは、各頂点についてk個の"ユニークな"経路が見つかるまで切り捨てないようにすることで、

(m,n)地点にk番目に短い経路長でたどり着く場合の長さを得れるようなアルゴリズムを組んだ。

[ (m,n)地点にk番目に短い経路長でたどり着くのが目的の場合にも、他の頂点への経路は高々k個確定したところで切り捨てて良いため。]

また、工夫としては、辺へのコストに少し特殊な重みを付けることで、優先的にゴールに近い頂点から確定するようにした。

そのため、大きいケースでは秒単位での高速化が計れた。

-

ダイクストラ

優先度付きキューを用いたため、計算量はO(|E| log |V|)である。

探索空間を減らすために使う。

また経路が1つも存在しない場合、無駄な探索を行ってしまうことが分かるので、それを回避するためにも実行する。

これを実行することで大きいケースにおいては手元で500ms程の改善が見られた。

しかし、小さいケースに対しては100ms程低速化した。

が、大きいケースに対しての恩恵が大きいので常に適用した。

-

・メモ化再帰探索

メモ化のための状態を、[現在のx座標][現在のy座標][たどり着くのに掛かったコスト]の三次元で持った。

計算量はO(mnk)程度だが、定数項が大きいためやや低速に動作する。

今の引数で計算した結果を記憶しておき、同じ引数で呼び出された時にその結果を返すことで、再計算を防いでいる。

-

・探索空間を減らす工夫について

最短経路だけでなく、k番目のパスを探す場合にも、各頂点について以下の事実が分かる。

・スタートから最短でたどり着ける場合のコストがsMinCostの時、sMinCost未満のための配列を用意する必要がないことは明らかである。

・また、ゴールから最短でたどり着ける場合のコストがgMinCostの時、目標の距離にgMinCostより多く掛かる場合、無駄な探索をしてしまうため、そのための配列を用意する必要はない。

以上のことを踏まえると、各頂点について、(gMinCost - sMinCost + 1)分のメモリだけ確保しておけば良いことが分かる。

そして、k≦200なので、最悪の場合でも(gMinCost - sMinCost + 1)は8000程度であり、少々多いが現実的なメモリ使用量に収まる。

■欠点

・メモ化再帰を用いているため、環境によってはスタックオーバーフローを引き起こす可能性がある。

・また、動的確保を用いているため、稀にかなり低速になることがある。(12sで終わるものが2分掛かったりした)