1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21 package org.dbunit.ant;
22
23 import java.io.File;
24 import java.io.IOException;
25 import java.net.MalformedURLException;
26 import java.sql.SQLException;
27 import java.util.ArrayList;
28 import java.util.Iterator;
29 import java.util.List;
30
31 import org.apache.tools.ant.ProjectComponent;
32 import org.dbunit.DatabaseUnitException;
33 import org.dbunit.database.DatabaseConfig;
34 import org.dbunit.database.IDatabaseConnection;
35 import org.dbunit.database.QueryDataSet;
36 import org.dbunit.dataset.CachedDataSet;
37 import org.dbunit.dataset.CompositeDataSet;
38 import org.dbunit.dataset.DataSetException;
39 import org.dbunit.dataset.ForwardOnlyDataSet;
40 import org.dbunit.dataset.IDataSet;
41 import org.dbunit.dataset.csv.CsvProducer;
42 import org.dbunit.dataset.excel.XlsDataSet;
43 import org.dbunit.dataset.stream.IDataSetProducer;
44 import org.dbunit.dataset.stream.StreamingDataSet;
45 import org.dbunit.dataset.xml.FlatDtdProducer;
46 import org.dbunit.dataset.xml.FlatXmlProducer;
47 import org.dbunit.dataset.xml.XmlProducer;
48 import org.dbunit.dataset.yaml.YamlProducer;
49 import org.dbunit.util.FileHelper;
50 import org.slf4j.Logger;
51 import org.slf4j.LoggerFactory;
52 import org.xml.sax.InputSource;
53
54
55
56
57
58
59
60 public abstract class AbstractStep extends ProjectComponent implements DbUnitTaskStep
61 {
62
63
64
65
66 private static final Logger logger = LoggerFactory.getLogger(AbstractStep.class);
67
68 public static final String FORMAT_FLAT = "flat";
69 public static final String FORMAT_XML = "xml";
70 public static final String FORMAT_DTD = "dtd";
71 public static final String FORMAT_CSV = "csv";
72 public static final String FORMAT_XLS = "xls";
73 public static final String FORMAT_YML = "yml";
74
75 private boolean ordered = false;
76
77
78 protected IDataSet getDatabaseDataSet(IDatabaseConnection connection,
79 List tables) throws DatabaseUnitException
80 {
81 if (logger.isDebugEnabled())
82 {
83 logger.debug("getDatabaseDataSet(connection={}, tables={}) - start",
84 connection, tables);
85 }
86
87 try
88 {
89 DatabaseConfig config = connection.getConfig();
90
91
92 if (tables.size() == 0)
93 {
94 logger.debug("Retrieving the whole database because tables/queries have not been specified");
95 return connection.createDataSet();
96 }
97
98 List queryDataSets = createQueryDataSet(tables, connection);
99
100 IDataSet[] dataSetsArray = null;
101 if (config.getProperty(DatabaseConfig.PROPERTY_RESULTSET_TABLE_FACTORY)
102 .getClass().getName().equals("org.dbunit.database.ForwardOnlyResultSetTableFactory")) {
103 dataSetsArray = createForwardOnlyDataSetArray(queryDataSets);
104 } else {
105 dataSetsArray = (IDataSet[]) queryDataSets.toArray(new IDataSet[queryDataSets.size()]);
106 }
107 return new CompositeDataSet(dataSetsArray);
108 }
109 catch (SQLException e)
110 {
111 throw new DatabaseUnitException(e);
112 }
113 }
114
115
116 private ForwardOnlyDataSet[] createForwardOnlyDataSetArray(List<QueryDataSet> dataSets) throws DataSetException, SQLException {
117 ForwardOnlyDataSet[] forwardOnlyDataSets = new ForwardOnlyDataSet[dataSets.size()];
118
119 for (int i = 0; i < dataSets.size(); i++) {
120 forwardOnlyDataSets[i] = new ForwardOnlyDataSet(dataSets.get(i));
121 }
122
123 return forwardOnlyDataSets;
124 }
125
126 private List createQueryDataSet(List tables, IDatabaseConnection connection)
127 throws DataSetException, SQLException
128 {
129 logger.debug("createQueryDataSet(tables={}, connection={})", tables, connection);
130
131 List queryDataSets = new ArrayList();
132
133 QueryDataSet queryDataSet = new QueryDataSet(connection);
134
135 for (Iterator it = tables.iterator(); it.hasNext();)
136 {
137 Object item = it.next();
138
139 if(item instanceof QuerySet) {
140 if(queryDataSet.getTableNames().length > 0)
141 queryDataSets.add(queryDataSet);
142
143 QueryDataSet newQueryDataSet = (((QuerySet)item).getQueryDataSet(connection));
144 queryDataSets.add(newQueryDataSet);
145 queryDataSet = new QueryDataSet(connection);
146 }
147 else if (item instanceof Query)
148 {
149 Query queryItem = (Query)item;
150 queryDataSet.addTable(queryItem.getName(), queryItem.getSql());
151 }
152 else if (item instanceof Table)
153 {
154 Table tableItem = (Table)item;
155 queryDataSet.addTable(tableItem.getName());
156 }
157 else
158 {
159 throw new IllegalArgumentException("Unsupported element type " + item.getClass().getName() + ".");
160 }
161 }
162
163 if(queryDataSet.getTableNames().length > 0)
164 queryDataSets.add(queryDataSet);
165
166 return queryDataSets;
167 }
168
169
170 protected IDataSet getSrcDataSet(File src, String format,
171 boolean forwardonly) throws DatabaseUnitException
172 {
173 if (logger.isDebugEnabled())
174 {
175 logger.debug("getSrcDataSet(src={}, format={}, forwardonly={}) - start",
176 src, format, forwardonly);
177 }
178
179 try
180 {
181 IDataSetProducer producer = null;
182 if (format.equalsIgnoreCase(FORMAT_XML))
183 {
184 producer = new XmlProducer(getInputSource(src));
185 }
186 else if (format.equalsIgnoreCase(FORMAT_CSV))
187 {
188 producer = new CsvProducer(src);
189 }
190 else if (format.equalsIgnoreCase(FORMAT_FLAT))
191 {
192 producer = new FlatXmlProducer(getInputSource(src), true, true);
193 }
194 else if (format.equalsIgnoreCase(FORMAT_DTD))
195 {
196 producer = new FlatDtdProducer(getInputSource(src));
197 }
198 else if (format.equalsIgnoreCase(FORMAT_XLS))
199 {
200 return new CachedDataSet(new XlsDataSet(src));
201 }
202 else if (format.equalsIgnoreCase(FORMAT_YML))
203 {
204 return new CachedDataSet(new YamlProducer(src), true);
205 }
206 else
207 {
208 throw new IllegalArgumentException("Type must be either 'flat'(default), 'xml', 'csv', 'xls', 'yml' or 'dtd' but was: " + format);
209 }
210
211 if (forwardonly)
212 {
213 return new StreamingDataSet(producer);
214 }
215 return new CachedDataSet(producer);
216 }
217 catch (IOException e)
218 {
219 throw new DatabaseUnitException(e);
220 }
221 }
222
223
224
225
226
227
228
229
230
231
232
233 public boolean isDataFormat(String format)
234 {
235 logger.debug("isDataFormat(format={}) - start", format);
236
237 return format.equalsIgnoreCase(FORMAT_FLAT)
238 || format.equalsIgnoreCase(FORMAT_XML)
239 || format.equalsIgnoreCase(FORMAT_CSV)
240 || format.equalsIgnoreCase(FORMAT_XLS)
241 || format.equalsIgnoreCase(FORMAT_YML);
242 }
243
244
245
246
247
248
249
250
251
252
253 protected void checkDataFormat(String format)
254 {
255 logger.debug("checkDataFormat(format={}) - start", format);
256
257 if (!isDataFormat(format))
258 {
259 throw new IllegalArgumentException("format must be either 'flat'(default), 'xml', 'csv', 'xls' or 'yml' but was: " + format);
260 }
261 }
262
263
264
265
266
267
268
269
270 public static InputSource getInputSource(File file) throws MalformedURLException
271 {
272 InputSource source = FileHelper.createInputSource(file);
273 return source;
274 }
275
276 public boolean isOrdered()
277 {
278 return ordered;
279 }
280
281 public void setOrdered(boolean ordered)
282 {
283 this.ordered = ordered;
284 }
285
286 public String toString()
287 {
288 final StringBuilder result = new StringBuilder();
289 result.append("AbstractStep: ");
290 result.append("ordered=").append(this.ordered);
291 return result.toString();
292 }
293
294 }