// MtasSolrResultMerge.java

package mtas.solr.handler.component.util;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;
import java.util.SortedSet;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.common.util.SimpleOrderedMap;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.handler.component.ShardRequest;
import org.apache.solr.handler.component.ShardResponse;

import mtas.solr.handler.component.MtasSolrSearchComponent;

/**
 * Merges the MTAS sections of the shard responses of a distributed request
 * into the {@code mtas} section of the main response.
 *
 * <p>
 * Which components are merged depends on the current stage of the
 * {@link ResponseBuilder} and on the request parameters that enable the
 * individual MTAS components.
 */
public class MtasSolrResultMerge {

  /** The Constant log. */
  private static final Log log = LogFactory.getLog(MtasSolrResultMerge.class);

  /**
   * Merges, for every finished shard request, the MTAS component results that
   * belong to the current stage into the main response.
   *
   * <p>
   * Nothing happens unless the request has the MTAS parameter enabled. An
   * already present {@code NamedList} stored under the MTAS component name is
   * reused and updated in place; otherwise a new list is created, and it is
   * only added to the response if the merge produced at least one entry.
   *
   * @param rb the response builder holding the request, the finished shard
   *        requests, the current stage and the response to update
   */
  @SuppressWarnings("unchecked")
  public void merge(ResponseBuilder rb) {
    if (!isEnabled(rb, MtasSolrSearchComponent.PARAM_MTAS)) {
      return;
    }

    // reuse the mtas section of the response when one is already available
    NamedList<Object> mtasResponse;
    boolean newResponse;
    Object existing = rb.rsp.getValues().get(MtasSolrSearchComponent.NAME);
    if (existing instanceof NamedList) {
      mtasResponse = (NamedList<Object>) existing;
      newResponse = false;
    } else {
      if (existing != null) {
        log.debug("ignoring existing mtas response of unexpected type "
            + existing.getClass().getName());
      }
      mtasResponse = new SimpleOrderedMap<>();
      newResponse = true;
    }

    for (ShardRequest sreq : rb.finished) {
      if (rb.stage == ResponseBuilder.STAGE_EXECUTE_QUERY) {
        // merge stats
        if (isEnabled(rb, MtasSolrComponentStats.PARAM_MTAS_STATS)) {
          mergeNamedList(sreq, mtasResponse, MtasSolrComponentStats.NAME,
              null);
        }
        // merge group
        if (isEnabled(rb, MtasSolrComponentGroup.PARAM_MTAS_GROUP)) {
          mergeArrayList(sreq, mtasResponse, MtasSolrComponentGroup.NAME, null,
              false);
        }
        // merge facet
        if (isEnabled(rb, MtasSolrComponentFacet.PARAM_MTAS_FACET)) {
          mergeArrayList(sreq, mtasResponse, MtasSolrComponentFacet.NAME, null,
              false);
        }
        // merge collection
        if (isEnabled(rb, MtasSolrComponentCollection.PARAM_MTAS_COLLECTION)) {
          mergeArrayList(sreq, mtasResponse, MtasSolrComponentCollection.NAME,
              null, false);
        }
        // merge prefix
        if (isEnabled(rb, MtasSolrComponentPrefix.PARAM_MTAS_PREFIX)) {
          mergeArrayList(sreq, mtasResponse, MtasSolrComponentPrefix.NAME,
              null, false);
        }
      } else if (rb.stage == MtasSolrSearchComponent.STAGE_COLLECTION_INIT) {
        // merge collection
        if (isEnabled(rb, MtasSolrComponentCollection.PARAM_MTAS_COLLECTION)) {
          mergeArrayList(sreq, mtasResponse, MtasSolrComponentCollection.NAME,
              null, false);
        }
      } else if (rb.stage == MtasSolrSearchComponent.STAGE_TERMVECTOR_MISSING_KEY) {
        // merge termvector
        if (isEnabled(rb, MtasSolrComponentTermvector.PARAM_MTAS_TERMVECTOR)) {
          mergeArrayList(sreq, mtasResponse, MtasSolrComponentTermvector.NAME,
              null, false);
        }
      } else if (rb.stage == MtasSolrSearchComponent.STAGE_LIST) {
        // merge list
        if (isEnabled(rb, MtasSolrComponentList.PARAM_MTAS_LIST)) {
          mergeArrayList(sreq, mtasResponse, MtasSolrComponentList.NAME,
              ShardRequest.PURPOSE_PRIVATE, true);
        }
      } else if (rb.stage == ResponseBuilder.STAGE_GET_FIELDS) {
        // merge document
        if (isEnabled(rb, MtasSolrComponentDocument.PARAM_MTAS_DOCUMENT)) {
          mergeArrayList(sreq, mtasResponse, MtasSolrComponentDocument.NAME,
              ShardRequest.PURPOSE_PRIVATE, true);
        }
        // merge kwic
        if (isEnabled(rb, MtasSolrComponentKwic.PARAM_MTAS_KWIC)) {
          mergeArrayList(sreq, mtasResponse, MtasSolrComponentKwic.NAME,
              ShardRequest.PURPOSE_PRIVATE, true);
        }
      }
    }

    // a reused response was updated in place; a new one is only worth adding
    // when the merge actually produced content
    if (newResponse && mtasResponse.size() > 0) {
      rb.rsp.add(MtasSolrSearchComponent.NAME, mtasResponse);
    }
  }

  /**
   * Checks whether a boolean request parameter is set to true.
   *
   * @param rb the response builder holding the request
   * @param param the name of the request parameter
   * @return true if the parameter is present and true, false otherwise
   */
  private static boolean isEnabled(ResponseBuilder rb, String param) {
    return rb.req.getParams().getBool(param, false);
  }

  /**
   * Looks up the value stored under {@code mtas/<key>} in a shard response.
   *
   * @param response the shard response
   * @param key the name of the component section below {@code mtas}
   * @return the value found, or null if the shard response carries no
   *         response body or has no such entry
   */
  private static Object findShardData(ShardResponse response, String key) {
    // guard against a shard response without payload, which would otherwise
    // abort the complete merge with a NullPointerException
    if (response.getSolrResponse() == null) {
      return null;
    }
    NamedList<Object> result = response.getSolrResponse().getResponse();
    return (result == null) ? null : result.findRecursive("mtas", key);
  }

  /**
   * Merges the {@code NamedList} stored under {@code mtas/<key>} in the shard
   * responses of a shard request into the entry with the same key in the main
   * MTAS response. The entry in the main response is (re)created as an empty
   * list when it is absent or not a {@code NamedList}.
   *
   * <p>
   * At most one response per shard is kept: a further response for a shard
   * that was already collected is skipped, unless the purpose of the shard
   * request equals the preferred purpose, in which case it replaces the one
   * collected before.
   *
   * @param sreq the shard request whose responses are merged
   * @param mtasResponse the main MTAS response to update
   * @param key the name of the component section
   * @param preferredPurpose the preferred purpose, or null for none
   */
  @SuppressWarnings("unchecked")
  private void mergeNamedList(ShardRequest sreq, NamedList<Object> mtasResponse,
      String key, Integer preferredPurpose) {
    // create new response for key
    NamedList<Object> mtasListResponse;
    Object o = mtasResponse.get(key);
    if (o instanceof NamedList) {
      mtasListResponse = (NamedList<Object>) o;
    } else {
      mtasListResponse = new SimpleOrderedMap<>();
      mtasResponse.removeAll(key);
      mtasResponse.add(key, mtasListResponse);
    }
    boolean preferred = (preferredPurpose != null)
        && (sreq.purpose == preferredPurpose);
    // collect responses for each shard
    HashMap<String, NamedList<Object>> mtasListShardResponses = new HashMap<>();
    for (ShardResponse response : sreq.responses) {
      // skip (only) this response if its shard is already covered and the
      // request does not have the preferred purpose; the remaining responses
      // may still belong to shards that were not seen before
      if (!preferred
          && mtasListShardResponses.containsKey(response.getShard())) {
        continue;
      }
      // update
      try {
        NamedList<Object> data = (NamedList<Object>) findShardData(response,
            key);
        if (data != null) {
          mtasListShardResponses.put(response.getShard(),
              MtasSolrResultUtil.decode(data));
        }
      } catch (ClassCastException e) {
        log.debug(e);
      }
    }
    try {
      for (NamedList<Object> mtasListShardResponse : mtasListShardResponses
          .values()) {
        mergeResponsesNamedList(mtasListResponse, mtasListShardResponse);
      }
    } catch (IOException e) {
      log.error(e);
    }
  }

  /**
   * Merges the {@code ArrayList} stored under {@code mtas/<key>} in the shard
   * responses of a shard request into the entry with the same key in the main
   * MTAS response. The entry in the main response is (re)created as an empty
   * list when it is absent or not an {@code ArrayList}.
   *
   * <p>
   * The first list found for a shard address is always merged. Further lists
   * for the same shard address are merged as well if
   * {@code mergeAllShardResponses} is set, and ignored otherwise.
   *
   * @param sreq the shard request whose responses are merged
   * @param mtasResponse the main MTAS response to update
   * @param key the name of the component section
   * @param preferredPurpose the preferred purpose, or null for none
   * @param mergeAllShardResponses whether additional responses for an already
   *        collected shard address should be merged too
   */
  @SuppressWarnings("unchecked")
  private void mergeArrayList(ShardRequest sreq, NamedList<Object> mtasResponse,
      String key, Integer preferredPurpose, boolean mergeAllShardResponses) {
    // create new response for key
    ArrayList<Object> mtasListResponse;
    Object o = mtasResponse.get(key);
    if (o instanceof ArrayList) {
      mtasListResponse = (ArrayList<Object>) o;
    } else {
      mtasListResponse = new ArrayList<>();
      mtasResponse.removeAll(key);
      mtasResponse.add(key, mtasListResponse);
    }
    boolean preferred = (preferredPurpose != null)
        && (sreq.purpose == preferredPurpose);
    // collect responses for each shard
    HashMap<String, ArrayList<Object>> mtasListShardResponses = new HashMap<>();
    ArrayList<ArrayList<Object>> mtasListShardResponsesExtra = new ArrayList<>();
    for (ShardResponse response : sreq.responses) {
      // skip (only) this response if its shard is already covered and the
      // request does not have the preferred purpose; the remaining responses
      // may still belong to shards that were not seen before
      // NOTE(review): this check uses getShard() while the map is keyed by
      // getShardAddress(); confirm which of the two is intended
      if (!preferred
          && mtasListShardResponses.containsKey(response.getShard())) {
        continue;
      }
      // update
      try {
        ArrayList<Object> data = (ArrayList<Object>) findShardData(response,
            key);
        if (data != null) {
          if (!mtasListShardResponses
              .containsKey(response.getShardAddress())) {
            mtasListShardResponses.put(response.getShardAddress(), data);
          } else if (mergeAllShardResponses) {
            mtasListShardResponsesExtra.add(data);
          }
        }
      } catch (ClassCastException e) {
        log.error(e);
      }
    }

    try {
      for (ArrayList<Object> mtasListShardResponse : mtasListShardResponses
          .values()) {
        mergeResponsesArrayList(mtasListResponse, mtasListShardResponse);
      }
      for (ArrayList<Object> mtasListShardResponse : mtasListShardResponsesExtra) {
        mergeResponsesArrayList(mtasListResponse, mtasListShardResponse);
      }
    } catch (IOException e) {
      log.error(e);
    }
  }

  /**
   * Adds all items of the shard set to the original set.
   *
   * @param originalList the set to extend
   * @param shardList the set providing the items to add
   */
  private void mergeResponsesSortedSet(SortedSet<Object> originalList,
      SortedSet<Object> shardList) {
    originalList.addAll(shardList);
  }

  /**
   * Merges the items of a shard list into the original list.
   *
   * <p>
   * Items that are a {@code NamedList} with a String value under
   * {@code "key"} are matched on that key: if the original list already
   * contains an item with the same key and of the same class, both are merged,
   * and if it contains none, a copy of the shard item is appended. All other
   * items are appended unchanged.
   *
   * @param originalList the list to update
   * @param shardList the list providing the items to merge
   * @throws IOException Signals that an I/O exception has occurred.
   */
  @SuppressWarnings("unchecked")
  private void mergeResponsesArrayList(ArrayList<Object> originalList,
      ArrayList<Object> shardList) throws IOException {
    // get keys from original
    HashMap<String, Object> originalKeyList = new HashMap<>();
    for (Object item : originalList) {
      if (item instanceof NamedList) {
        Object key = ((NamedList<Object>) item).get("key");
        if (key instanceof String) {
          originalKeyList.put((String) key, item);
        }
      }
    }
    for (Object item : shardList) {
      Object key = (item instanceof NamedList)
          ? ((NamedList<Object>) item).get("key") : null;
      if (!(key instanceof String)) {
        // no usable key to match on: just append
        originalList.add(item);
        continue;
      }
      // values in originalKeyList are never null, so a null result means
      // that the key is not registered yet
      Object originalItem = originalKeyList.get(key);
      if (originalItem == null) {
        // add
        Object clonedItem = adjustablePartsCloned(item);
        originalList.add(clonedItem);
        originalKeyList.put((String) key, clonedItem);
      } else if (originalItem.getClass().equals(item.getClass())) {
        // merge
        mergeResponsesNamedList((NamedList<Object>) originalItem,
            (NamedList<Object>) item);
      }
      // an item of another class under an already known key is ignored
    }
  }

  /**
   * Merges the entries of a shard response into the main response.
   *
   * <p>
   * An entry whose name is not yet in the main response, or whose value in
   * the main response is null, is taken over as a copy. Otherwise the values
   * are only combined if both are of exactly the same class: lists, named
   * lists and sorted sets are merged recursively, MTAS results are merged by
   * their own merge method, and Integer and Long values are added up. Values
   * of any other class, including String, are left as they are in the main
   * response.
   *
   * @param mainResponse the response to update
   * @param shardResponse the response providing the entries to merge
   * @throws IOException Signals that an I/O exception has occurred.
   */
  @SuppressWarnings("unchecked")
  private void mergeResponsesNamedList(NamedList<Object> mainResponse,
      NamedList<Object> shardResponse) throws IOException {
    for (Entry<String, Object> entry : shardResponse) {
      String name = entry.getKey();
      Object shardValue = entry.getValue();
      int originalId = mainResponse.indexOf(name, 0);
      if (originalId < 0) {
        mainResponse.add(name, adjustablePartsCloned(shardValue));
        continue;
      }
      Object original = mainResponse.getVal(originalId);
      if (original == null) {
        // nothing to merge with: store the shard value, so it is not lost
        mainResponse.setVal(originalId, adjustablePartsCloned(shardValue));
      } else if (shardValue != null
          && original.getClass().equals(shardValue.getClass())) {
        if (original instanceof ArrayList) {
          // merge ArrayList
          mergeResponsesArrayList((ArrayList<Object>) original,
              (ArrayList<Object>) shardValue);
        } else if (original instanceof NamedList) {
          // merge NamedList
          mergeResponsesNamedList((NamedList<Object>) original,
              (NamedList<Object>) shardValue);
        } else if (original instanceof SortedSet) {
          // merge SortedSet
          mergeResponsesSortedSet((SortedSet<Object>) original,
              (SortedSet<Object>) shardValue);
        } else if (original instanceof MtasSolrMtasResult) {
          ((MtasSolrMtasResult) original)
              .merge((MtasSolrMtasResult) shardValue);
        } else if (original instanceof MtasSolrCollectionResult) {
          ((MtasSolrCollectionResult) original)
              .merge((MtasSolrCollectionResult) shardValue);
        } else if (original instanceof Integer) {
          // boxed numbers are immutable, so the sum has to be stored
          mainResponse.setVal(originalId,
              Integer.valueOf((Integer) original + (Integer) shardValue));
        } else if (original instanceof Long) {
          mainResponse.setVal(originalId,
              Long.valueOf((Long) original + (Long) shardValue));
        }
        // any other type (e.g. String): keep the value of the main response
      }
      // null shard value or different classes: keep the value of the main
      // response
    }
  }

  /**
   * Creates a copy in which all (nested) {@code NamedList} and
   * {@code ArrayList} containers are new instances, so they can be adjusted
   * without affecting the original. A {@code NamedList} is always copied into
   * a {@code SimpleOrderedMap}; every other value is not copied but shared
   * with the original.
   *
   * @param original the value to copy, may be null
   * @return the copy, or the original itself if it is not a container
   */
  @SuppressWarnings("unchecked")
  private Object adjustablePartsCloned(Object original) {
    if (original instanceof NamedList) {
      NamedList<Object> source = (NamedList<Object>) original;
      NamedList<Object> copy = new SimpleOrderedMap<>();
      for (int i = 0; i < source.size(); i++) {
        copy.add(source.getName(i), adjustablePartsCloned(source.getVal(i)));
      }
      return copy;
    }
    if (original instanceof ArrayList) {
      ArrayList<Object> source = (ArrayList<Object>) original;
      ArrayList<Object> copy = new ArrayList<>(source.size());
      for (Object element : source) {
        copy.add(adjustablePartsCloned(element));
      }
      return copy;
    }
    return original;
  }

}