1
2
3
4
5
6
7
8 __author__ = "Ian Haywood"
9 __license__ = "GPL v2 or later (details at http://www.gnu.org)"
10
11
12 import sys
13 import os
14 import re
15 import logging
16 import io
17
18
# Logger used by every log call in this module; records go to the
# 'gm.bootstrapper' logging channel.
_log = logging.getLogger('gm.bootstrapper')

# NOTE(review): not referenced in the visible code; presumably an error id
# used when an error message cannot be formatted -- confirm at its use site.
unformattable_error_id = 12345
22
23
25
27 """
28 db : the interpreter to connect to, must be a DBAPI compliant interface
29 """
30 self.conn = conn
31 self.vars = {'ON_ERROR_STOP': None}
32
33
def match(self, pattern):
    """Try *pattern* against the beginning of the current input line.

    pattern: regular expression applied with re.match() to self.line

    On success the pattern's captured groups are saved in self.groups and
    1 is returned.  On failure 0 is returned and self.groups keeps
    whatever value it had before.
    """
    hit = re.match(pattern, self.line)
    if hit:
        # match objects are always truthy, so this is the "matched" case
        self.groups = hit.groups()
        return 1
    return 0
41
42
57
58
def run(self, filename):
    """Execute the semicolon-separated SQL commands stored in *filename*.

    filename: a UTF-8 text file containing semicolon-separated SQL commands

    The file is split into commands on ';' characters that are outside
    single-quoted strings and outside parentheses; '--' comments outside
    strings are dropped.  Each command is executed on its own cursor and
    committed right away.  Lines starting with a backslash (when not
    continuing an open string literal) are treated as psql meta-commands:

      \\set NAME VALUE   store VALUE in self.vars[NAME]; ON_ERROR_STOP is
                        converted to an int (the psql booleans on/off,
                        true/false, yes/no are accepted as well)
      \\unset NAME       set self.vars[NAME] to None
      anything else     logged and ignored

    A failing command is logged, including any attributes of the
    exception's ``diag`` object; if ON_ERROR_STOP is set, processing stops.

    Returns 0 when the whole file was processed, 1 when the file could not
    be opened or a command failed while ON_ERROR_STOP was set.
    """
    _log.debug('processing [%s]', filename)
    curs = self.conn.cursor()
    curs.execute('show session authorization')
    start_auth = curs.fetchall()[0][0]
    curs.close()
    _log.debug('session auth: %s', start_auth)

    # Just try to open the file: an os.access() pre-check races with the
    # open() and also reports directories as readable.
    try:
        sql_file = io.open(filename, mode='rt', encoding='utf8')
    except (IOError, OSError):
        _log.error("cannot open file [%s]", filename)
        return 1

    self.lineno = 0
    self.filename = filename
    in_string = False
    bracketlevel = 0
    curr_cmd = ''
    curs = self.conn.cursor()

    # the with-block closes the file on every exit path, including the
    # early return under ON_ERROR_STOP
    with sql_file:
        # the loop target is self.line because self.match() reads it
        for self.line in sql_file:
            self.lineno += 1
            if not self.line.strip():
                continue

            # psql meta-commands; a backslash line that continues an open
            # string literal is SQL data, not a command
            if not in_string:
                if self.match(r"^\\set (\S+) (\S+)"):
                    _log.debug('"\\set" found: %s', self.groups)
                    var_name, var_value = self.groups
                    self.vars[var_name] = var_value
                    if var_name == 'ON_ERROR_STOP':
                        flag = var_value.lower()
                        if flag in ('on', 'true', 'yes'):
                            self.vars['ON_ERROR_STOP'] = 1
                        elif flag in ('off', 'false', 'no'):
                            self.vars['ON_ERROR_STOP'] = 0
                        else:
                            self.vars['ON_ERROR_STOP'] = int(var_value)
                    continue

                if self.match(r"^\\unset (\S+)"):
                    self.vars[self.groups[0]] = None
                    continue

                if self.match(r"^\\(.*)"):
                    _log.warning(self.fmt_msg("psql command \"\\%s\" being ignored " % self.groups[0]))
                    continue

            # character scanner: this_char is the character being processed,
            # next_char is one character of look-ahead; the appended ' '
            # gives the last real character of the line a look-ahead
            this_char = self.line[0]
            for next_char in self.line[1:] + ' ':
                if this_char == "'":
                    in_string = not in_string

                if this_char == '-' and next_char == '-' and not in_string:
                    # The rest of the line is a comment.  Keep a line break
                    # so the text before the comment does not run into the
                    # next line ("select a--x\nfrom t" must not become
                    # "select afrom t").
                    curr_cmd += '\n'
                    break

                if not in_string:
                    if this_char == '(':
                        bracketlevel += 1
                    elif this_char == ')':
                        bracketlevel -= 1

                # a ';' terminates the command only outside strings and
                # outside parentheses
                if in_string or bracketlevel != 0 or this_char != ';':
                    curr_cmd += this_char
                else:
                    if curr_cmd.strip():
                        try:
                            curs.execute(curr_cmd)
                            try:
                                data = curs.fetchall()
                                _log.debug('cursor data: %s', data)
                            except Exception:
                                # best effort: statements that produce no
                                # result set make fetchall() raise
                                pass
                        except Exception as error:
                            _log.exception(curr_cmd)
                            if str(error).startswith('NOTICE:'):
                                _log.warning(self.fmt_msg(error))
                            else:
                                _log.error(self.fmt_msg(error))
                            if hasattr(error, 'diag'):
                                for prop in dir(error.diag):
                                    if prop.startswith('__'):
                                        continue
                                    val = getattr(error.diag, prop)
                                    if val is None:
                                        continue
                                    _log.error('PG diags %s: %s', prop, val)
                            if self.vars['ON_ERROR_STOP']:
                                self.conn.commit()
                                curs.close()
                                return 1

                    self.conn.commit()
                    curs.close()
                    curs = self.conn.cursor()
                    curr_cmd = ''

                this_char = next_char

    if curr_cmd.strip():
        # text after the last terminating ';' (or inside an unclosed
        # string) is never sent to the server
        _log.warning('unterminated SQL at end of [%s] not executed: %s', filename, curr_cmd)

    self.conn.commit()
    curs.execute('show session authorization')
    end_auth = curs.fetchall()[0][0]
    curs.close()
    _log.debug('session auth after sql file processing: %s', end_auth)
    if start_auth != end_auth:
        _log.error('session auth changed before/after processing sql file')

    return 0
180
181
182
if __name__ == '__main__':

    # test mode: python <this file> test <file.sql>
    if len(sys.argv) < 3:
        sys.exit()

    if sys.argv[1] != 'test':
        sys.exit()

    # NOTE(review): PgSQL is not imported anywhere in the visible source, so
    # this line raises NameError as written; import a DB-API 2.0 driver and
    # adapt the connect() call before relying on this test driver.
    conn = PgSQL.connect(user='gm-dbo', database='gnumed')
    try:
        psql = Psql(conn)
        # argv[1] is the literal 'test' marker; the SQL file is argv[2]
        exit_code = psql.run(sys.argv[2])
    finally:
        conn.close()
    # propagate run()'s 0/1 status to the shell
    sys.exit(exit_code)
196