1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21 package org.dbunit.database;
22
23 import java.sql.PreparedStatement;
24 import java.sql.ResultSet;
25 import java.sql.SQLException;
26 import java.util.*;
27
28 import org.dbunit.database.search.ForeignKeyRelationshipEdge;
29 import org.dbunit.dataset.DataSetException;
30 import org.dbunit.dataset.IDataSet;
31 import org.dbunit.dataset.ITable;
32 import org.dbunit.dataset.ITableIterator;
33 import org.dbunit.dataset.ITableMetaData;
34 import org.dbunit.dataset.filter.AbstractTableFilter;
35 import org.dbunit.util.SQLHelper;
36 import org.slf4j.Logger;
37 import org.slf4j.LoggerFactory;
38
39
40
41
42
43
44
45
46
47
48
49
50
51 public class PrimaryKeyFilter extends AbstractTableFilter {
52
53 private final IDatabaseConnection connection;
54
55 private final PkTableMap allowedPKsPerTable;
56 private final PkTableMap allowedPKsInput;
57 private final PkTableMap pksToScanPerTable;
58
59 private final boolean reverseScan;
60
61 protected final Logger logger = LoggerFactory.getLogger(getClass());
62
63
64 private final Map pkColumnPerTable = new HashMap();
65
66 private final Map fkEdgesPerTable = new HashMap();
67 private final Map fkReverseEdgesPerTable = new HashMap();
68
69
70 private final List tableNames = new ArrayList();
71
72
73
74
75
76
77
78
79
80
81
82 public PrimaryKeyFilter(IDatabaseConnection connection, PkTableMap allowedPKs, boolean reverseDependency) {
83 this.connection = connection;
84 this.allowedPKsPerTable = new PkTableMap();
85 this.allowedPKsInput = allowedPKs;
86 this.reverseScan = reverseDependency;
87
88
89 this.pksToScanPerTable = new PkTableMap(allowedPKs);
90 }
91
92 public void nodeAdded(Object node) {
93 this.tableNames.add( node );
94 if ( this.logger.isDebugEnabled() ) {
95 this.logger.debug("nodeAdded: " + node );
96 }
97 }
98
99 public void edgeAdded(ForeignKeyRelationshipEdge edge) {
100 if ( this.logger.isDebugEnabled() ) {
101 this.logger.debug("edgeAdded: " + edge );
102 }
103
104 String from = (String) edge.getFrom();
105 Set edges = (Set) this.fkEdgesPerTable.get(from);
106 if ( edges == null ) {
107 edges = new HashSet();
108 this.fkEdgesPerTable.put( from, edges );
109 }
110 if ( ! edges.contains(edge) ) {
111 edges.add(edge);
112 }
113
114
115 String to = (String) edge.getTo();
116 edges = (Set) this.fkReverseEdgesPerTable.get(to);
117 if ( edges == null ) {
118 edges = new HashSet();
119 this.fkReverseEdgesPerTable.put(to, edges);
120 }
121 if ( ! edges.contains(edge) ) {
122 edges.add(edge);
123 }
124
125
126 updatePkCache(to, edge);
127
128 }
129
130
131
132
133 public boolean isValidName(String tableName) throws DataSetException {
134
135
136 return true;
137 }
138
139 public ITableIterator iterator(IDataSet dataSet, boolean reversed)
140 throws DataSetException {
141 if ( this.logger.isDebugEnabled() ) {
142 this.logger.debug("Filter.iterator()" );
143 }
144 try {
145 searchPKs(dataSet);
146 } catch (SQLException e) {
147 throw new DataSetException( e );
148 }
149 return new FilterIterator(reversed ? dataSet.reverseIterator() : dataSet
150 .iterator());
151 }
152
153 private void searchPKs(IDataSet dataSet) throws DataSetException, SQLException {
154 logger.debug("searchPKs(dataSet={}) - start", dataSet);
155
156 int counter = 0;
157 while ( !this.pksToScanPerTable.isEmpty() ) {
158 counter ++;
159 if ( this.logger.isDebugEnabled() ) {
160 this.logger.debug( "RUN # " + counter );
161 }
162
163 for( int i=this.tableNames.size()-1; i>=0; i-- ) {
164 String tableName = (String) this.tableNames.get(i);
165
166 String pkColumn = dataSet.getTable(tableName).getTableMetaData().getPrimaryKeys()[0].getColumnName();
167 Set tmpSet = this.pksToScanPerTable.get( tableName );
168 if ( tmpSet != null && ! tmpSet.isEmpty() ) {
169 Set pksToScan = new HashSet( tmpSet );
170 if ( this.logger.isDebugEnabled() ) {
171 this.logger.debug( "before search: "+ tableName + "=>" + pksToScan );
172 }
173 scanPKs( tableName, pkColumn, pksToScan );
174 scanReversePKs( tableName, pksToScan );
175 allowPKs( tableName, pksToScan );
176 removePKsToScan( tableName, pksToScan );
177 }
178 }
179 removeScannedTables();
180 }
181 if ( this.logger.isDebugEnabled() ) {
182 this.logger.debug( "Finished searchIds()" );
183 }
184 }
185
186 private void removeScannedTables() {
187 logger.debug("removeScannedTables() - start");
188 this.pksToScanPerTable.retainOnly(this.tableNames);
189 }
190
191 private void allowPKs(String table, Set newAllowedPKs) {
192 logger.debug("allowPKs(table={}, newAllowedPKs={}) - start", table, newAllowedPKs);
193
194
195 Set forcedAllowedPKs = this.allowedPKsInput.get( table );
196 if( forcedAllowedPKs == null || forcedAllowedPKs.isEmpty() ) {
197 allowedPKsPerTable.addAll(table, newAllowedPKs );
198 } else {
199 for(Iterator iterator = newAllowedPKs.iterator(); iterator.hasNext(); ) {
200 Object id = iterator.next();
201 if( forcedAllowedPKs.contains(id) ) {
202 allowedPKsPerTable.add(table, id);
203 }
204 else
205 {
206 logger.debug("Discarding id {} of table {} as it was not included in the input!", id, table);
207 }
208 }
209 }
210 }
211
212 private void scanPKs( String table, String pkColumn, Set allowedIds ) throws SQLException {
213 if (logger.isDebugEnabled())
214 {
215 logger.debug("scanPKs(table={}, pkColumn={}, allowedIds={}) - start",
216 new Object[]{ table, pkColumn, allowedIds });
217 }
218
219 Set fkEdges = (Set) this.fkEdgesPerTable.get( table );
220 if ( fkEdges == null || fkEdges.isEmpty() ) {
221 return;
222 }
223
224 List fkTables = new ArrayList( fkEdges.size() );
225 final StringBuilder colsBuffer = new StringBuilder();
226 for(Iterator iterator = fkEdges.iterator(); iterator.hasNext(); ) {
227 ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator.next();
228 fkTables.add( edge.getTo() );
229 colsBuffer.append( edge.getFKColumn() );
230 if ( iterator.hasNext() ) {
231 colsBuffer.append( ", " );
232 }
233 }
234
235 String sql = "SELECT " + colsBuffer + " FROM " + table +
236 " WHERE " + pkColumn + " = ? ";
237 if ( this.logger.isDebugEnabled() ) {
238 this.logger.debug( "SQL: " + sql );
239 }
240
241 scanPKs(table, sql, allowedIds, fkTables);
242 }
243
244 private void scanPKs(String table, String sql, Set allowedIds, List fkTables) throws SQLException
245 {
246 PreparedStatement pstmt = null;
247 ResultSet rs = null;
248 try {
249 pstmt = this.connection.getConnection().prepareStatement( sql );
250 for(Iterator iterator = allowedIds.iterator(); iterator.hasNext(); ) {
251 Object pk = iterator.next();
252 if( this.logger.isDebugEnabled() ) {
253 this.logger.debug("Executing sql for ? = " + pk );
254 }
255 pstmt.setObject( 1, pk );
256 rs = pstmt.executeQuery();
257 while( rs.next() ) {
258 for( int i=0; i<fkTables.size(); i++ ) {
259 String newTable = (String) fkTables.get(i);
260 Object fk = rs.getObject(i+1);
261 if( fk != null ) {
262 logger.debug("New ID: {}->{}", newTable, fk);
263 addPKToScan( newTable, fk );
264 }
265 else {
266 logger.warn( "Found null FK for relationship {} =>{}", table, newTable );
267 }
268 }
269 }
270 }
271 } catch (SQLException e) {
272 logger.error("scanPKs()", e);
273 }
274 finally {
275
276 SQLHelper.close( rs, pstmt );
277 }
278 }
279
280 private void scanReversePKs(String table, Set pksToScan) throws SQLException {
281 logger.debug("scanReversePKs(table={}, pksToScan={}) - start", table, pksToScan);
282
283 if ( ! this.reverseScan ) {
284 return;
285 }
286 Set fkReverseEdges = (Set) this.fkReverseEdgesPerTable.get( table );
287 if ( fkReverseEdges == null || fkReverseEdges.isEmpty() ) {
288 return;
289 }
290 Iterator iterator = fkReverseEdges.iterator();
291 while ( iterator.hasNext() ) {
292 ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator.next();
293 addReverseEdge( edge, pksToScan );
294 }
295 }
296
297 private void addReverseEdge(ForeignKeyRelationshipEdge edge, Set idsToScan) throws SQLException {
298 logger.debug("addReverseEdge(edge={}, idsToScan=) - start", edge, idsToScan);
299
300 String fkTable = (String) edge.getFrom();
301 String fkColumn = edge.getFKColumn();
302 String pkColumn = getPKColumn( fkTable );
303
304 String sql = "SELECT " + pkColumn + " FROM " + fkTable + " WHERE " + fkColumn + " = ? ";
305
306 PreparedStatement pstmt = null;
307 ResultSet rs = null;
308 try {
309 logger.debug("Preparing SQL query '{}'", sql);
310
311 pstmt = this.connection.getConnection().prepareStatement(sql);
312 for(Iterator iterator = idsToScan.iterator(); iterator.hasNext(); ) {
313 Object pk = iterator.next();
314 if ( this.logger.isDebugEnabled() ) {
315 this.logger.debug( "executing query '" + sql + "' for ? = " + pk );
316 }
317 pstmt.setObject( 1, pk );
318 rs = pstmt.executeQuery();
319 while( rs.next() ) {
320 Object fk = rs.getObject(1);
321 addPKToScan( fkTable, fk );
322 }
323 }
324 } finally {
325 SQLHelper.close( rs, pstmt );
326 }
327 }
328
329 private void updatePkCache(String table, ForeignKeyRelationshipEdge edge) {
330 logger.debug("updatePkCache(to={}, edge={}) - start", table, edge);
331
332 Object pkTo = this.pkColumnPerTable.get(table);
333 if ( pkTo == null ) {
334 String pkColumn = edge.getPKColumn();
335 this.pkColumnPerTable.put( table, pkColumn );
336 }
337 }
338
339
340 private String getPKColumn( String table ) throws SQLException {
341 logger.debug("getPKColumn(table={}) - start", table);
342
343
344 String pkColumn = (String) this.pkColumnPerTable.get( table );
345 if ( pkColumn == null ) {
346
347 pkColumn = SQLHelper.getPrimaryKeyColumn( this.connection.getConnection(), table );
348 this.pkColumnPerTable.put( table, pkColumn );
349 }
350 return pkColumn;
351 }
352
353
354 private void removePKsToScan(String table, Set ids) {
355 logger.debug("removePKsToScan(table={}, ids={}) - start", table, ids);
356
357 Set pksToScan = this.pksToScanPerTable.get(table);
358 if ( pksToScan != null ) {
359 if ( pksToScan == ids ) {
360 throw new RuntimeException( "INTERNAL ERROR on removeIdsToScan() for table " + table );
361 } else {
362 pksToScan.removeAll( ids );
363 }
364 }
365 }
366
367 private void addPKToScan(String table, Object pk) {
368 logger.debug("addPKToScan(table={}, pk={}) - start", table, pk);
369
370
371 if(this.allowedPKsPerTable.contains(table, pk)) {
372 if ( this.logger.isDebugEnabled() ) {
373 this.logger.debug( "Discarding already scanned id=" + pk + " for table " + table );
374 }
375 return;
376 }
377
378 this.pksToScanPerTable.add(table, pk);
379 }
380
381 public String toString() {
382 final StringBuilder sb = new StringBuilder();
383 sb.append("tableNames=").append(tableNames);
384 sb.append(", allowedPKsInput=").append(allowedPKsInput);
385 sb.append(", allowedPKsPerTable=").append(allowedPKsPerTable);
386 sb.append(", fkEdgesPerTable=").append(fkEdgesPerTable);
387 sb.append(", fkReverseEdgesPerTable=").append(fkReverseEdgesPerTable);
388 sb.append(", pkColumnPerTable=").append(pkColumnPerTable);
389 sb.append(", pksToScanPerTable=").append(pksToScanPerTable);
390 sb.append(", reverseScan=").append(reverseScan);
391 sb.append(", connection=").append(connection);
392 return sb.toString();
393 }
394
395
396 private class FilterIterator implements ITableIterator {
397
398 private final ITableIterator _iterator;
399
400 public FilterIterator(ITableIterator iterator) {
401
402 _iterator = iterator;
403 }
404
405
406
407
408 public boolean next() throws DataSetException {
409 if ( logger.isDebugEnabled() ) {
410 logger.debug("Iterator.next()" );
411 }
412 while (_iterator.next()) {
413 if (accept(_iterator.getTableMetaData().getTableName())) {
414 return true;
415 }
416 }
417 return false;
418 }
419
420 public ITableMetaData getTableMetaData() throws DataSetException {
421 if ( logger.isDebugEnabled() ) {
422 logger.debug("Iterator.getTableMetaData()" );
423 }
424 return _iterator.getTableMetaData();
425 }
426
427 public ITable getTable() throws DataSetException {
428 if ( logger.isDebugEnabled() ) {
429 logger.debug("Iterator.getTable()" );
430 }
431 ITable table = _iterator.getTable();
432 String tableName = table.getTableMetaData().getTableName();
433 Set allowedPKs = allowedPKsPerTable.get( tableName );
434 if ( allowedPKs != null ) {
435 return new PrimaryKeyFilteredTableWrapper(table, allowedPKs);
436 }
437 return table;
438 }
439 }
440
441
442
443
444
445
446
447
448
449 public static class PkTableMap
450 {
451 private final LinkedHashMap pksPerTable;
452 private final Logger logger = LoggerFactory.getLogger(PkTableMap.class);
453
454 public PkTableMap()
455 {
456 this.pksPerTable = new LinkedHashMap();
457 }
458
459
460
461
462
463 public PkTableMap(PkTableMap allowedPKs) {
464 this.pksPerTable = new LinkedHashMap();
465 Iterator iterator = allowedPKs.pksPerTable.entrySet().iterator();
466 while ( iterator.hasNext() ) {
467 Map.Entry entry = (Map.Entry) iterator.next();
468 String table = (String)entry.getKey();
469 SortedSet pkObjectSet = (SortedSet) entry.getValue();
470 SortedSet newSet = new TreeSet( pkObjectSet );
471 this.pksPerTable.put( table, newSet );
472 }
473 }
474
475 public int size() {
476 return pksPerTable.size();
477 }
478
479 public boolean isEmpty() {
480 return pksPerTable.isEmpty();
481 }
482
483 public boolean contains(String table, Object pkObject) {
484 Set pksPerTable = this.get(table);
485 return (pksPerTable != null && pksPerTable.contains(pkObject));
486 }
487
488 public void remove(String tableName) {
489 this.pksPerTable.remove(tableName);
490 }
491
492 public void put(String table, SortedSet pkObjects) {
493 this.pksPerTable.put(table, pkObjects);
494 }
495
496 public void add(String tableName, Object pkObject) {
497 Set pksPerTable = getCreateIfNeeded(tableName);
498 pksPerTable.add(pkObject);
499 }
500
501 public void addAll(String tableName, Set pkObjectsToAdd) {
502 Set pksPerTable = this.getCreateIfNeeded(tableName);
503 pksPerTable.addAll(pkObjectsToAdd);
504 }
505
506 public SortedSet get(String tableName) {
507 return (SortedSet) this.pksPerTable.get(tableName);
508 }
509
510 private SortedSet getCreateIfNeeded(String tableName){
511 SortedSet pksPerTable = this.get(tableName);
512
513 if( pksPerTable == null ) {
514 pksPerTable = new TreeSet();
515 this.pksPerTable.put(tableName, pksPerTable);
516 }
517 return pksPerTable;
518 }
519
520 public String[] getTableNames() {
521 return (String[]) this.pksPerTable.keySet().toArray(new String[0]);
522 }
523
524 public void retainOnly(List tableNames) {
525
526 List tablesToRemove = new ArrayList();
527 for(Iterator iterator = this.pksPerTable.entrySet().iterator(); iterator.hasNext(); ) {
528 Map.Entry entry = (Map.Entry) iterator.next();
529 String table = (String) entry.getKey();
530 SortedSet pksToScan = (SortedSet) entry.getValue();
531 boolean removeIt = pksToScan.isEmpty();
532
533 if ( ! tableNames.contains(table) ) {
534 if ( this.logger.isWarnEnabled() ) {
535 this.logger.warn("Discarding ids " + pksToScan + " of table " + table +
536 "as this table has not been passed as input" );
537 }
538 removeIt = true;
539 }
540 if ( removeIt ) {
541 tablesToRemove.add( table );
542 }
543 }
544
545 for(Iterator iterator = tablesToRemove.iterator(); iterator.hasNext(); ) {
546 this.remove( (String)iterator.next() );
547 }
548 }
549
550
551 public String toString() {
552 final StringBuilder sb = new StringBuilder();
553 sb.append("pKsPerTable=").append(pksPerTable);
554 return sb.toString();
555 }
556
557 }
558 }