44 #include "EST_cutils.h"
45 #include "EST_Token.h"
46 #include "EST_Wagon.h"
53 return impurity.value();
54 else if (question.ask(d))
55 return left->predict(d);
57 return right->predict(d);
64 else if (question.ask(d))
65 return left->predict_node(d);
67 return right->predict_node(d);
74 if ((left == 0) && (right == 0))
76 else if (get_impurity().type() != wnim_class)
82 void WNode::prune(
void)
90 if (left != 0) left->prune();
91 if (right != 0) right->prune();
95 if ((left->pure() == TRUE) && ((right->pure() == TRUE)) &&
96 (left->get_impurity().value() == right->get_impurity().value()))
98 delete left; left = 0;
99 delete right; right = 0;
105 void WNode::held_out_prune()
116 wgn_score_question(question,get_data());
117 if (question.get_score() < get_impurity().measure())
119 wgn_find_split(question,get_data(),
122 left->held_out_prune();
123 right->held_out_prune();
127 delete left; left = 0;
128 delete right; right = 0;
133 void WNode::print_out(ostream &s,
int margin)
138 for (i=0;i<margin;i++) s <<
" ";
145 left->print_out(s,margin+1);
146 right->print_out(s,margin+1);
151 ostream & operator <<(ostream &s,
WNode &n)
160 void WDataSet::ignore_non_numbers()
165 for (i=0; i<dlength; i++)
167 if ((p_type[i] == wndt_binary) ||
168 (p_type[i] == wndt_float))
179 void WDataSet::load_description(
const EST_String &fname, LISP ignores)
186 description = car(vload(fname,1));
187 dlength = siod_llength(description);
193 if (wgn_predictee_name ==
"")
198 for (i=0,d=description; d != NIL; d=cdr(d),i++)
200 p_name[i] = get_c_string(car(car(d)));
201 tname = get_c_string(car(cdr(car(d))));
203 if ((wgn_predictee_name !=
"") && (wgn_predictee_name == p_name[i]))
205 if ((wgn_count_field_name !=
"") &&
206 (wgn_count_field_name == p_name[i]))
208 if ((tname ==
"count") || (i == wgn_count_field))
211 p_type[i] = wndt_ignore;
215 else if ((tname ==
"ignore") || (siod_member_str(p_name[i],ignores)))
217 p_type[i] = wndt_ignore;
219 if (i == wgn_predictee)
220 wagon_error(
EST_String(
"predictee \"")+p_name[i]+
221 "\" can't be ignored \n");
223 else if (siod_llength(car(d)) > 2)
225 LISP rest = cdr(car(d));
227 siod_list_to_strlist(rest,sl);
228 p_type[i] = wgn_discretes.def(sl);
229 if (streq(get_c_string(car(rest)),
"_other_"))
230 wgn_discretes[p_type[i]].def_val(
"_other_");
232 else if (tname ==
"binary")
233 p_type[i] = wndt_binary;
234 else if (tname ==
"cluster")
235 p_type[i] = wndt_cluster;
236 else if (tname ==
"vector")
237 p_type[i] = wndt_vector;
238 else if (tname ==
"trajectory")
239 p_type[i] = wndt_trajectory;
240 else if (tname ==
"matrix")
241 p_type[i] = wndt_matrix;
242 else if (tname ==
"float")
243 p_type[i] = wndt_float;
246 wagon_error(
EST_String(
"Unknown type \"")+tname+
248 "/"+p_name[i]+
" in description file \""+fname+
"\"");
252 if (wgn_predictee == -1)
254 wagon_error(
EST_String(
"predictee field \"")+wgn_predictee_name+
255 "\" not found in description ");
259 const int WQuestion::ask(
const WVector &w)
const
265 if (w.get_flt_val(feature_pos) == operand1.
Float())
270 if (w.get_int_val(feature_pos) == 1)
274 case wnop_greaterthan:
275 if (w.get_flt_val(feature_pos) > operand1.
Float())
280 if (w.get_flt_val(feature_pos) < operand1.
Float())
285 if (w.get_int_val(feature_pos) == operand1.
Int())
290 if (ilist_member(operandl,w.get_int_val(feature_pos)))
295 wagon_error(
"Unknown test operator");
301 ostream& operator<<(ostream& s,
const WQuestion &q)
304 static EST_Regex needquotes(
".*[()'\";., \t\n\r].*");
306 s <<
"(" << wgn_dataset.feat_name(q.get_fp());
310 s <<
" = " << q.get_operand1().
string();
314 case wnop_greaterthan:
315 s <<
" > " << q.get_operand1().
Float();
318 s <<
" < " << q.get_operand1().
Float();
321 name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
322 name(q.get_operand1().
Int());
325 s << quote_string(name,
"\"",
"\\",1);
330 name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
331 name(q.get_operand1().
Int());
332 s <<
" matches " << quote_string(name,
"\"",
"\\",1);
336 for (
int l=0; l < q.get_operandl().length(); l++)
338 name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
339 name(q.get_operandl().
nth(l));
340 if (name.matches(needquotes))
341 s << quote_string(name,
"\"",
"\\",1);
365 cerr <<
"WImpurity: no value currently set\n";
368 else if (t==wnim_class)
370 else if (t==wnim_cluster)
372 else if (t==wnim_vector)
374 else if (t==wnim_trajectory)
380 double WImpurity::samples(
void)
384 else if (t==wnim_class)
386 else if (t==wnim_cluster)
387 return members.length();
388 else if (t==wnim_vector)
389 return members.length();
390 else if (t==wnim_trajectory)
391 return members.length();
401 a.
reset(); trajectory=0; l=0; width=0;
402 for (i=0; i < ds.
n(); i++)
404 if (wgn_count_field == -1)
405 cumulate((*(ds(i)))[wgn_predictee],1);
407 cumulate((*(ds(i)))[wgn_predictee],
408 (*(ds(i)))[wgn_count_field]);
412 float WImpurity::measure(
void)
416 else if (t == wnim_vector)
417 return vector_impurity();
418 else if (t == wnim_trajectory)
419 return trajectory_impurity();
420 else if (t == wnim_matrix)
422 else if (t == wnim_class)
424 else if (t == wnim_cluster)
425 return cluster_impurity();
428 cerr <<
"WImpurity: can't measure unset object" << endl;
433 float WImpurity::vector_impurity()
448 if (wgn_VertexFeats.
a(0,j) > 0.0)
451 for (pp=members.head(); pp != 0; pp=pp->next())
453 i = members.
item(pp);
454 b += wgn_VertexTrack.
a(i,j);
473 if (wgn_VertexFeats.
a(0,j) > 0.0)
475 for (pp=members.head(); pp != 0; pp=pp->next())
476 cs[j][j] += wgn_VertexTrack.
a(members.
item(pp),j);
482 if (wgn_VertexFeats.
a(0,j) > 0.0)
484 for (pp=members.head(); pp != 0; pp=pp->next())
486 mmm = members.
item(pp);
487 cs[i][j] += (wgn_VertexTrack.
a(mmm,i)-cs[j][j].
mean())*
488 (wgn_VertexTrack.
a(mmm,j)-cs[j][j].
mean());
495 if (wgn_VertexFeats.
a(0,j) > 0.0)
507 for (pp=members.head(); pp != 0; pp=pp->next())
509 x = members.
item(pp);
511 for (qq=pp->next(); qq != 0; qq=qq->next())
513 y = members.
item(qq);
515 if (wgn_VertexFeats.
a(0,j) > 0.0)
517 d = wgn_VertexTrack(x,j)-wgn_VertexTrack(y,j);
527 return a.
mean() * count;
530 WImpurity::~WImpurity()
537 delete [] trajectory[j];
538 delete [] trajectory;
545 float WImpurity::trajectory_impurity()
557 double n, m, m1, m2, w;
570 for (pp=members.head(); pp != 0; pp=pp->next())
572 i = members.
item(pp);
573 for (q=0; q<wgn_UnitTrack.
a(i,1); q++)
575 ni = (int)wgn_UnitTrack.
a(i,0)+q;
576 if (wgn_VertexTrack.
a(ni,0) == -1.0)
583 if (q==wgn_UnitTrack.
a(i,1))
589 l2ss += wgn_UnitTrack.
a(i,1) - (q+1) - 1;
590 lss += wgn_UnitTrack.
a(i,1);
591 if (wgn_UnitTrack.
a(i,1) > l)
592 l = (
int)wgn_UnitTrack.
a(i,1);
597 l = ((int)lss.
mean() < 7) ? 7 : (
int)lss.
mean();
605 for (pp=members.head(); pp != 0; pp=pp->next())
607 i = members.
item(pp);
608 m = (float)wgn_UnitTrack.
a(i,1)/(float)l;
609 s = (int)wgn_UnitTrack.
a(i,0);
610 for (ti=0,n=0.0; ti<l; ti++,n+=m)
615 if (wgn_VertexFeats.
a(0,j) > 0.0)
616 trajectory[ti][j] += wgn_VertexTrack.
a(s+ni,j);
623 for (ti=0; ti<l; ti++)
626 if (wgn_VertexFeats.
a(0,j) > 0.0)
627 stdss += trajectory[ti][j].
stddev();
631 score = stdss.
mean() * members.length();
635 l1 = (l1ss.
mean() < 10.0) ? 10 : (int)l1ss.
mean();
636 l2 = (l2ss.
mean() < 10.0) ? 10 : (int)l2ss.
mean();
644 for (pp=members.head(); pp != 0; pp=pp->next())
646 i = members.
item(pp);
648 s = (int)wgn_UnitTrack.
a(i,0);
649 for (q=0; q<wgn_UnitTrack.
a(i,1); q++)
650 if (wgn_VertexTrack.
a(s+q,0) == -1.0)
655 s2l = (int)wgn_UnitTrack.
a(i,1) - (s1l + 2);
656 m1 = (float)(s1l)/(float)l1;
657 m2 = (float)(s2l)/(float)l2;
659 for (ti=0,n=0.0; s1l > 0 && ti<l1; ti++,n+=m1)
661 ni = s + (((int)n < s1l) ? (int)n : s1l - 1);
663 if (wgn_VertexFeats.
a(0,j) > 0.0)
664 trajectory[ti][j] += wgn_VertexTrack.
a(ni,j);
668 if (wgn_VertexFeats.
a(0,j) > 0.0)
669 trajectory[ti][j] += -1;
672 for (ti++,n=0.0; s2l > 0 && ti<l-1; ti++,n+=m2)
674 ni = s + (((int)n < s2l) ? (int)n : s2l - 1);
676 if (wgn_VertexFeats.
a(0,j) > 0.0)
677 trajectory[ti][j] += wgn_VertexTrack.
a(ni,j);
680 if (wgn_VertexFeats.
a(0,j) > 0.0)
681 trajectory[ti][j] += -2;
688 for (w=0.0,ti=0; ti<l1; ti++,w+=m)
690 if (wgn_VertexFeats.
a(0,j) > 0.0)
691 stdss += trajectory[ti][j].
stddev() * w;
693 for (w=1.0,ti++; ti<l-1; ti++,w-=m)
695 if (wgn_VertexFeats.
a(0,j) > 0.0)
696 stdss += trajectory[ti][j].
stddev() * w;
699 score = stdss.
mean() * members.length();
704 float WImpurity::cluster_impurity()
715 for (pp=members.head(); pp != 0; pp=pp->next())
717 i = members.
item(pp);
718 for (q=pp->next(); q != 0; q=q->next())
721 dist = (j < i ? wgn_DistMatrix.
a_no_check(i,j) :
735 float WImpurity::cluster_distance(
int i)
739 float dist = cluster_member_mean(i);
740 float mdist = dist-a.
mean();
749 int WImpurity::in_cluster(
int i)
753 float dist = cluster_member_mean(i);
756 for (pp=members.head(); pp != 0; pp=pp->next())
758 if (dist < cluster_member_mean(members.
item(pp)))
764 float WImpurity::cluster_ranking(
int i)
767 float dist = cluster_distance(i);
771 for (pp=members.head(); pp != 0; pp=pp->next())
773 if (dist >= cluster_distance(members.
item(pp)))
780 float WImpurity::cluster_member_mean(
int i)
788 for (sum=0.0,n=0,q=members.head(); q != 0; q=q->next())
793 dist = (j < i ? wgn_DistMatrix(i,j) : wgn_DistMatrix(j,i));
799 return ( n == 0 ? 0.0 : sum/n );
802 void WImpurity::cumulate(
const float pv,
double count)
806 if (wgn_dataset.ftype(wgn_predictee) == wndt_cluster)
811 else if (wgn_dataset.ftype(wgn_predictee) == wndt_vector)
816 else if (wgn_dataset.ftype(wgn_predictee) == wndt_trajectory)
821 else if (wgn_dataset.ftype(wgn_predictee) >= wndt_class)
824 p.
init(&wgn_discretes[wgn_dataset.ftype(wgn_predictee)]);
828 else if (wgn_dataset.ftype(wgn_predictee) == wndt_binary)
831 a.cumulate((
int)pv,count);
833 else if (wgn_dataset.ftype(wgn_predictee) == wndt_float)
836 a.cumulate(pv,count);
840 wagon_error(
"WImpurity: cannot cumulate EST_Val type");
844 ostream & operator <<(ostream &s,
WImpurity &imp)
849 if (imp.t == wnim_float)
850 s <<
"(" << imp.a.
stddev() <<
" " << imp.a.
mean() <<
")";
851 else if (imp.t == wnim_vector)
855 imp.vector_impurity();
856 if (wgn_vertex_output ==
"mean")
861 for (p=imp.members.head(); p != 0; p=p->next())
863 b += wgn_VertexTrack.
a(imp.members.
item(p),j);
865 s <<
"(" << b.
mean() <<
" " << b.
stddev() <<
")";
873 double best = WGN_HUGE_VAL;
881 if (wgn_VertexFeats.
a(0,j) > 0.0)
884 for (p=imp.members.head(); p != 0; p=p->next())
886 cs[j] += wgn_VertexTrack.
a(imp.members.
item(p),j);
890 for (p=imp.members.head(); p != 0; p=p->next())
893 if (wgn_VertexFeats.
a(0,j) > 0.0)
895 d = (wgn_VertexTrack.
a(imp.members.
item(p),j)-cs[j].mean())
901 bestp = imp.members.
item(p);
908 s << wgn_VertexTrack.
a(bestp,j);
911 if (finite(cs[j].stddev()))
923 s << imp.a.
mean() <<
")";
925 else if (imp.t == wnim_trajectory)
928 imp.trajectory_impurity();
929 for (i=0; i<imp.l; i++)
934 s <<
"(" << imp.trajectory[i][j].
mean() <<
" "
935 << imp.trajectory[i][j].
stddev() <<
" " <<
")";
941 s << imp.a.
mean() <<
")";
943 else if (imp.t == wnim_cluster)
947 for (p=imp.members.head(); p != 0; p=p->next())
950 s <<
"(" << imp.members.
item(p) <<
" " <<
951 imp.cluster_member_mean(imp.members.
item(p)) <<
")";
957 s << imp.a.
mean() <<
")";
959 else if (imp.t == wnim_class)
969 s <<
"(" << name <<
" " << prob <<
") ";
974 s <<
"([WImpurity unset])";
INLINE const T & a_no_check(int row, int col) const
const access with no bounds check, care recommended
EST_Litem * item_next(EST_Litem *idx) const
Used for iterating through members of the distribution.
const EST_String & most_probable(double *prob=NULL) const
Return the most probable member of the distribution.
double stddev(void) const
standard deviation of currently cumulated values
float & a(int i, int c=0)
double samples(void) const
Total number of examples found.
const int Int(void) const
A Regular expression class to go with the CSTR EST_String class.
bool init(const EST_StrList &vocab)
Initialise using given vocabulary.
int num_channels() const
return number of channels in track
double mean(void) const
mean of currently cumulated values
EST_String itoString(int n)
Make a EST_String object from an integer.
const float Float(void) const
T & nth(int n)
return the Nth value
EST_Litem * item_start() const
Used for iterating through members of the distribution.
int item_end(EST_Litem *idx) const
Used for iterating through members of the distribution.
void resize(int n, int set=1)
void cumulate(const EST_String &s, double count=1)
Add this observation, may specify number of occurrences.
double entropy(void) const
void item_prob(EST_Litem *idx, EST_String &s, double &prob) const
During iteration returns name and probability given index.
double variance(void) const
variance of currently cumulated values
const EST_String & string(void) const
INLINE int n() const
number of items in vector.
void append(const T &item)
add item onto end of list
void reset(void)
reset internal values
double samples(void)
number of samples in set
T & item(const EST_Litem *p)
int matches(const char *e, int pos=0) const
Exactly match this string?
void resize(int n, int set=1)
resize vector