SHOGUN  v3.2.0
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
StructuredModel.cpp
浏览该文件的文档.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Thoralf Klein
8  * Written (W) 2012 Fernando José Iglesias García
9  * Copyright (C) 2012 Fernando José Iglesias García
10  */
11 
13 
14 using namespace shogun;
15 
16 CResultSet::CResultSet() : CSGObject(), argmax(NULL) // Default constructor: base CSGObject is default-constructed
17 { // argmax is initialized to NULL in the init list; nothing else is set in this body.
18 }
19 
21 {
23 }
24 
26 {
27  return new CStructuredLabels(num_labels);
28 }
29 
30 const char* CResultSet::get_name() const // Returns the class name of this object.
31 { // The name is a fixed string literal; the method does not depend on object state.
32  return "ResultSet";
33 }
34 
36 {
37  init();
38 }
39 
41  CFeatures* features,
42  CStructuredLabels* labels)
43 : CSGObject()
44 {
45  init();
46 
47  set_labels(labels);
48  set_features(features);
49 }
50 
52 {
55 }
56 
58  float64_t regularization,
66 {
67  SG_ERROR("init_primal_opt is not implemented for %s!\n", get_name())
68 }
69 
71 {
72  SG_REF(labels);
74  m_labels = labels;
75 }
76 
78 {
80  return m_labels;
81 }
82 
84 {
85  SG_REF(features);
87  m_features = features;
88 }
89 
91 {
93  return m_features;
94 }
95 
97  int32_t feat_idx,
98  int32_t lab_idx)
99 {
100  CStructuredData* label = m_labels->get_label(lab_idx);
101  SGVector< float64_t > ret = get_joint_feature_vector(feat_idx, label);
102  SG_UNREF(label);
103 
104  return ret;
105 }
106 
108  int32_t feat_idx,
109  CStructuredData* y)
110 {
111  SG_ERROR("compute_joint_feature(int32_t, CStructuredData*) is not "
112  "implemented for %s!\n", get_name());
113 
114  return SGVector< float64_t >();
115 }
116 
118 { // Delta loss between the ground-truth label stored at index ytrue_idx and the prediction ypred.
119  REQUIRE(ytrue_idx >= 0 && ytrue_idx < m_labels->get_num_labels(), // both bounds must hold ('||' made this check always true)
120  "The label index must be inside [0, num_labels-1]\n");
121 
122  CStructuredData* ytrue = m_labels->get_label(ytrue_idx);
123  float64_t ret = delta_loss(ytrue, ypred); // delegate to the (CStructuredData*, CStructuredData*) overload
124  SG_UNREF(ytrue); // drop the reference taken for ytrue (presumably get_label returns a SG_REF'd object — confirm)
125 
126  return ret;
127 }
128 
130 {
131  SG_ERROR("delta_loss(CStructuredData*, CStructuredData*) is not "
132  "implemented for %s!\n", get_name());
133 
134  return 0.0;
135 }
136 
137 void CStructuredModel::init()
138 {
139  SG_ADD((CSGObject**) &m_labels, "m_labels", "Structured labels",
141  SG_ADD((CSGObject**) &m_features, "m_features", "Feature vectors",
143 
144  m_features = NULL;
145  m_labels = NULL;
146 }
147 
149 {
150  // Nothing to do here
151 }
152 
154 {
155  // Nothing to do here
156  return true;
157 }
158 
160 {
161  return 0;
162 }
163 
165 {
166  return 0;
167 }
Base class of the labels used in Structured Output (SO) problems.
void set_labels(CStructuredLabels *labs)
#define SG_UNREF(x)
Definition: SGRefObject.h:35
#define SG_ERROR(...)
Definition: SGIO.h:131
#define REQUIRE(x,...)
Definition: SGIO.h:208
SGVector< float64_t > get_joint_feature_vector(int32_t feat_idx, int32_t lab_idx)
CStructuredLabels * get_labels()
virtual int32_t get_num_aux_con() const
virtual int32_t get_num_aux() const
void set_features(CFeatures *feats)
virtual const char * get_name() const
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:102
virtual CStructuredLabels * structured_labels_factory(int32_t num_labels=0)
double float64_t
Definition: common.h:48
#define SG_REF(x)
Definition: SGRefObject.h:34
float64_t delta_loss(int32_t ytrue_idx, CStructuredData *ypred)
virtual bool check_training_setup() const
CStructuredLabels * m_labels
The class Features is the base class of all feature objects.
Definition: Features.h:62
CStructuredData * argmax
virtual CStructuredData * get_label(int32_t idx)
#define SG_ADD(...)
Definition: SGObject.h:71
virtual const char * get_name() const
Base class of the components of StructuredLabels.
virtual void init_primal_opt(float64_t regularization, SGMatrix< float64_t > &A, SGVector< float64_t > a, SGMatrix< float64_t > B, SGVector< float64_t > &b, SGVector< float64_t > lb, SGVector< float64_t > ub, SGMatrix< float64_t > &C)

SHOGUN Machine Learning Toolbox - Documentation