/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.lucene.queries;

import org.apache.lucene.index.*;
import org.apache.lucene.search.*;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
/**
 * A {@link Query} that matches a single term across several fields and blends the
 * per-field document frequencies before scoring, so the term is scored consistently
 * no matter which field it appears in. Subclasses decide how the per-field
 * {@link TermQuery}s are combined at the top level (boolean or dismax).
 */
public abstract class BlendedTermQuery extends Query {

    private final String term;
    private final String[] fields;
    // bias = 0.0 scores every field with the averaged document frequency,
    // bias = 1.0 keeps each field's own document frequency untouched
    private final float bias;

    public BlendedTermQuery(String[] fields, String term) {
        this(fields, term, 0.0f);
    }

    public BlendedTermQuery(String[] fields, String term, float bias) {
        if (fields == null) {
            throw new IllegalArgumentException("fields must not be null");
        }
        if (term == null) {
            throw new IllegalArgumentException("term must not be null");
        }
        this.fields = fields;
        this.bias = bias;
        this.term = term;
    }
    @Override
    public Query rewrite(IndexReader reader) throws IOException {
        IndexReaderContext context = reader.getContext();
        // collect the term statistics for the term in every field
        Term[] terms = new Term[fields.length];
        TermContext[] ctx = new TermContext[fields.length];
        for (int i = 0; i < terms.length; i++) {
            terms[i] = new Term(fields[i], term);
            ctx[i] = TermContext.build(context, terms[i]);
        }
        // blend the statistics across fields, then let the subclass build the query
        blend(ctx, reader.maxDoc());
        return topLevelQuery(terms, ctx);
    }

    protected abstract Query topLevelQuery(Term[] terms, TermContext[] ctx);
    protected void blend(TermContext[] contexts, int maxDoc) {
        // sum up the document frequencies of the term across all fields
        long sum = 0;
        int numZeroDF = 0;
        for (TermContext ctx : contexts) {
            int df = ctx.docFreq();
            sum += df;
            if (df == 0) {
                numZeroDF++;
            }
        }
        if (sum == 0) {
            return; // the term doesn't exist in any field, nothing to blend
        }
        // average df over the fields that actually contain the term
        final long avg = sum / (contexts.length - numZeroDF);
        for (int i = 0; i < contexts.length; i++) {
            int df = contexts[i].docFreq();
            if (df == 0) {
                continue;
            }
            // interpolate between the average (bias = 0) and the field's own df (bias = 1)
            long blendedIDF = avg + ((long) (bias * (df - avg)));
            contexts[i].setDocFreq((int) Math.min(maxDoc, blendedIDF));
            contexts[i] = adjustTTF(contexts[i]);
        }
    }
    /**
     * Makes sure the total term frequency is never smaller than the (possibly raised)
     * document frequency, which would otherwise yield inconsistent statistics.
     */
    private TermContext adjustTTF(TermContext termContext) {
        if (termContext.docFreq() > termContext.totalTermFreq()) {
            TermContext newTermContext = new TermContext(termContext.topReaderContext);
            List<AtomicReaderContext> leaves = termContext.topReaderContext.leaves();
            final int len;
            if (leaves == null) {
                len = 1;
            } else {
                len = leaves.size();
            }
            int df = termContext.docFreq();
            long ttf = Math.max(df, termContext.totalTermFreq());
            // attach the aggregated stats to the first leaf that has a state; later leaves add zero
            for (int i = 0; i < len; i++) {
                TermState termState = termContext.get(i);
                if (termState == null) {
                    continue;
                }
                newTermContext.register(termState, i, df, ttf);
                df = 0;
                ttf = 0;
            }
            return newTermContext;
        }
        return termContext;
    }
    @Override
    public String toString(String field) {
        return "blended(\"" + term + "\", fields: " + Arrays.toString(fields) + ")";
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        if (!super.equals(o)) return false;
        BlendedTermQuery that = (BlendedTermQuery) o;
        if (Float.compare(that.bias, bias) != 0) return false;
        if (!Arrays.equals(fields, that.fields)) return false;
        if (!term.equals(that.term)) return false;
        return true;
    }

    @Override
    public int hashCode() {
        int result = super.hashCode();
        result = 31 * result + term.hashCode();
        result = 31 * result + Arrays.hashCode(fields);
        result = 31 * result + Float.floatToIntBits(bias);
        return result;
    }
    /** Blended query that combines the per-field term queries as SHOULD clauses of a boolean query. */
    public static BlendedTermQuery booleanBlendedQuery(String[] fields, String term, float bias, final boolean disableCoord) {
        return new BlendedTermQuery(fields, term, bias) {
            @Override
            protected Query topLevelQuery(Term[] terms, TermContext[] ctx) {
                BooleanQuery query = new BooleanQuery(disableCoord);
                for (int i = 0; i < terms.length; i++) {
                    query.add(new TermQuery(terms[i], ctx[i]), BooleanClause.Occur.SHOULD);
                }
                return query;
            }
        };
    }

    /** Blended query that combines the per-field term queries with a dismax and the given tie breaker. */
    public static BlendedTermQuery dismaxBlendedQuery(String[] fields, String term, float bias, final float tieBreakerMultiplier) {
        return new BlendedTermQuery(fields, term, bias) {
            @Override
            protected Query topLevelQuery(Term[] terms, TermContext[] ctx) {
                DisjunctionMaxQuery query = new DisjunctionMaxQuery(tieBreakerMultiplier);
                for (int i = 0; i < terms.length; i++) {
                    query.add(new TermQuery(terms[i], ctx[i]));
                }
                return query;
            }
        };
    }
}
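
/*
 * Minimal usage sketch, not part of the original gist: the "title"/"body" field
 * names, the term "lucene", and the searcher wiring are illustrative assumptions.
 * It builds the dismax flavour with bias = 0.0f (pure averaged document frequency
 * across fields) and a small tie breaker, then runs it against an IndexSearcher.
 */
class BlendedTermQueryExample {

    static TopDocs search(IndexSearcher searcher) throws IOException {
        Query query = BlendedTermQuery.dismaxBlendedQuery(
                new String[]{"title", "body"}, "lucene", 0.0f, 0.1f);
        return searcher.search(query, 10); // top 10 blended hits
    }
}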