All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
learn-piece.cc
Go to the documentation of this file.
1 /* learn-piece.cc
2  */
6 #include "osl/record/csaRecord.h"
7 #include "osl/record/ki2.h"
8 #include "osl/record/kakinoki.h"
9 #include "osl/record/kisen.h"
10 #include "osl/eval/see.h"
11 #include "osl/pieceStand.h"
12 #include <boost/algorithm/string/predicate.hpp>
13 #include <iostream>
14 using namespace osl;
15 using namespace std;
16 namespace csa=osl::record::csa;
17 CArray<int,PTYPE_SIZE> weight, gradient;
18 void show() {
19  for (size_t i=0; i<PieceStand::order.size(); ++i) {
20  Ptype ptype = PieceStand::order[i];
21  cout << csa::show(ptype) << ' ' << weight[ptype] << ' ';
22  if (canPromote(ptype))
23  cout << csa::show(promote(ptype)) << ' ' << weight[promote(ptype)] << ' ';
24  }
25  cout << endl;
26 #if 0
27  for (size_t i=0; i<PieceStand::order.size(); ++i) {
28  Ptype ptype = PieceStand::order[i];
29  cout << csa::show(ptype) << ' ' << gradient[ptype] << ' ';
30  if (canPromote(ptype))
31  cout << csa::show(promote(ptype)) << ' ' << gradient[promote(ptype)] << ' ';
32  }
33  cout << endl;
34 #endif
35 }
36 int median() {
37  osl::vector<int> copy;
38  for (int i=0; i<PTYPE_SIZE; ++i)
39  if (gradient[i]!=0) copy.push_back(gradient[i]);
40  sort(copy.begin(), copy.end());
41  if (copy.size() == 1) return 0;
42  if (copy.size()%2) return copy[copy.size()/2];
43  return copy[copy.size()/2]-1;
44 }
45 void update() {
46  std::vector<std::pair<int,Ptype> > gradient_ptype;
47  for (size_t i=0; i<PieceStand::order.size(); ++i) {
48  Ptype ptype = PieceStand::order[i];
49  gradient_ptype.push_back(std::make_pair(gradient[ptype], ptype));
50  if (canPromote(ptype)) {
51  ptype = promote(ptype);
52  gradient_ptype.push_back(std::make_pair(gradient[ptype], ptype));
53  }
54  }
55  std::sort(gradient_ptype.begin(), gradient_ptype.end());
56  // bonanza's robust update seems better than standard gradient descent methods, here
57  // const int a[13] = { -1, -1, -1, -1, -1, -1, 0, 1, 1, 1, 1, 1, 1 };
58  const int a[13] = { -3, -2, -2, -1, -1, -1, 0, 1, 1, 1, 2, 2, 3 };
59  for (size_t i=0; i<gradient_ptype.size(); ++i)
60  weight[gradient_ptype[i].second] += a[i];
61 }
62 void count(const NumEffectState& state, CArray<int,PTYPE_SIZE>& out) {
63  out.fill(0);
64  for (int i=0; i<Piece::SIZE; ++i) {
65  Piece p = state.pieceOf(i);
66  out[p.ptype()] += playerToSign(p.owner());
67  }
68 }
69 void compare(Player turn, const NumEffectState& selected,
70  const NumEffectState& not_selected) {
71  CArray<int,PTYPE_SIZE> c0, c1, diff;
72  count(selected, c0);
73  count(not_selected, c1);
74  int evaldiff = 0;
75  for (int i=0; i<PTYPE_SIZE; ++i) {
76  diff[i] = (c0[i] - c1[i])*playerToSign(turn);
77  evaldiff += diff[i] * weight[i];
78  }
79  if (evaldiff > 0) return;
80  for (int i=0; i<PTYPE_SIZE; ++i)
81  gradient[i] += diff[i];
82 }
83 Move greedymove(const NumEffectState& state) {
84  MoveVector all;
85  LegalMoves::generate(state, all);
86  int best_see = 0;
87  Move best_move;
88  for (size_t i=0; i<all.size(); ++i) {
89  if (! all[i].isCaptureOrPromotion()) continue;
90  int see = See::see(state, all[i]);
91  if (see <= best_see) continue;
92  best_see = see;
93  best_move = all[i];
94  }
95  return best_move;
96 }
97 void make_PV(const NumEffectState& src, Move prev, MoveVector& pv) {
98  NumEffectState state(src);
99  pv.clear();
100  // todo: quiescence search
101  while (true) {
102  state.makeMove(prev);
103  pv.push_back(prev);
104  Move move = greedymove(state);
105  if (! move.isNormal())
106  return;
107  prev = move;
108  }
109 }
110 void make_moves(NumEffectState& state, const MoveVector& pv) {
111  for (size_t i=0; i<pv.size(); ++i)
112  state.makeMove(pv[i]);
113 }
114 
115 void run(const osl::vector<Move>& moves) {
116  NumEffectState state;
117  for (size_t i=0; i<moves.size(); ++i) {
118  const Move selected = moves[i];
119  MoveVector all;
120  LegalMoves::generate(state, all);
121 
122  if (! state.hasEffectAt(alt(selected.player()), selected.to())) {
123  MoveVector pv0;
124  make_PV(state, selected, pv0);
125  NumEffectState s0(state);
126  make_moves(s0, pv0);
127  for (size_t j=0; j<all.size(); ++j)
128  if (all[j] != selected) {
129  MoveVector pv1;
130  make_PV(state, all[j], pv1);
131  NumEffectState s1(state);
132  make_moves(s1, pv1);
133  compare(state.turn(), s0, s1);
134  }
135  }
136  state.makeMove(selected);
137  }
138 }
139 int main(int argc, char **argv) {
140  weight.fill(500);
141  for (int t=0; t<1024; ++t) {
142  show();
143  gradient.fill(0);
144  for (int i=1; i<argc; ++i) {
145  const char *filename = argv[i];
146  if (boost::algorithm::iends_with(filename, ".csa")) {
147  const CsaFile csa(filename);
148  run(csa.getRecord().getMoves());
149  }
150  else if (boost::algorithm::iends_with(filename, ".ki2")) {
151  const Ki2File ki2(filename);
152  run(ki2.getRecord().getMoves());
153  }
154  else if (boost::algorithm::iends_with(filename, ".kif")
155  && KakinokiFile::isKakinokiFile(filename)) {
156  const KakinokiFile kif(filename);
157  run(kif.getRecord().getMoves());
158  }
159  else if (boost::algorithm::iends_with(filename, ".kif")) {
160  KisenFile kisen(filename);
161  for (size_t j=0; j<kisen.size(); ++j)
162  run(kisen.getMoves(j));
163  }
164  else {
165  cerr << "Unknown file type: " << filename << "\n";
166  continue;
167  }
168  }
169  update();
170  }
171 }
172 // ;;; Local Variables:
173 // ;;; mode:c++
174 // ;;; c-basic-offset:2
175 // ;;; coding:utf-8
176 // ;;; End: