Repository : ssh://g18-sc-serv-04.diamond.ac.uk/dials_scratch
On branch : master
commit ae9c94da11153b12f38a3fd4ad534cd9572e83c9
Author: James Beilsten-Edmands <[log in to unmask]>
Date: Wed Jul 4 14:51:37 2018 +0100
Update cross validation script for recent scaling code updates
ae9c94da11153b12f38a3fd4ad534cd9572e83c9
command_line/scale_cross_validate.py | 22 +++++++++++++++-------
1 file changed, 15 insertions(+), 7 deletions(-)
diff --git a/command_line/scale_cross_validate.py b/command_line/scale_cross_validate.py
index 57c7b9d..5492886 100644
--- a/command_line/scale_cross_validate.py
+++ b/command_line/scale_cross_validate.py
@@ -112,7 +112,7 @@ def cross_validate():
results_dict = {}
for i, v in enumerate(itertools.product(*values)):
e = dict(zip(keys, v))
- results_dict[i] = {"configuration": [], "Rwork": [], "Rfree": []}
+ results_dict[i] = {"configuration": [], "Rwork": [], "Rfree": [], "CCwork": [], "CCfree": []}
for k, v in e.iteritems():
params = set_parameter(params, k, v)
results_dict[i]["configuration"].append(str(k)+'='+str(v))
@@ -130,7 +130,7 @@ def cross_validate():
k = params.cross_validation.optimise_choice
params = set_parameter(params, k, value)
results_dict[i] = {"configuration": [str(k)+'='+str(value)],
- "Rwork": [], "Rfree": []}
+ "Rwork": [], "Rfree": [], "CCwork": [], "CCfree": []}
for n in range(params.cross_validation.nfolds):
if n < 100.0/params.scaling_options.free_set_percentage:
params.scaling_options.free_set_offset = n
@@ -145,7 +145,7 @@ def cross_validate():
k = params.cross_validation.optimise_parameter
params = set_parameter(params, k, value)
results_dict[i] = {"configuration": [str(k)+'='+str(value)],
- "Rwork": [], "Rfree": []}
+ "Rwork": [], "Rfree": [], "CCwork": [], "CCfree": []}
for n in range(params.cross_validation.nfolds):
if n < 100.0/params.scaling_options.free_set_percentage:
params.scaling_options.free_set_offset = n
@@ -188,8 +188,10 @@ def run_script(params, experiments, reflections, results_dict):
script = Script(params, experiments=deepcopy(experiments),
reflections=deepcopy(reflections))
script.run(save_data=False)
- results_dict["Rwork"].append(script.minimised.final_rmsds[3])
- results_dict["Rfree"].append(script.minimised.final_rmsds[4])
+ results_dict["Rwork"].append(script.scaler.final_rmsds[0])
+ results_dict["Rfree"].append(script.scaler.final_rmsds[1])
+ results_dict["CCwork"].append(script.scaler.final_rmsds[2])
+ results_dict["CCfree"].append(script.scaler.final_rmsds[3])
return results_dict
def interpret_results(results_dict):
@@ -197,17 +199,23 @@ def interpret_results(results_dict):
Expect a configuration and final_rmsds columns. Score the data and make a
nice table."""
rows = []
- headers = ['option', 'Rwork', 'Rfree']
+ headers = ['option', 'Rwork', 'Rfree', 'CCwork', 'CCfree']
free_rmsds = []
+ free_cc12s = []
for v in results_dict.itervalues():
config_str = ' '.join(v['configuration'])
avg_work = round(sum(v['Rwork'])/len(v['Rwork']), 5)
avg_free = round(sum(v['Rfree'])/len(v['Rfree']), 5)
- rows.append([config_str, str(avg_work), str(avg_free)])
+ avg_ccwork = round(sum(v['CCwork'])/len(v['CCwork']), 5)
+ avg_ccfree = round(sum(v['CCfree'])/len(v['CCfree']), 5)
+ rows.append([config_str, str(avg_work), str(avg_free), str(avg_ccwork), str(avg_ccfree)])
free_rmsds.append(avg_free)
+ free_cc12s.append(avg_ccfree)
#find lowest free rmsd
low_rmsd_idx = free_rmsds.index(min(free_rmsds))
+ high_cc12_idx = free_cc12s.index(max(free_cc12s))
rows[low_rmsd_idx][2] += '*'
+ rows[high_cc12_idx][4] += '*'
st = simple_table(rows, headers)
logger.info('Summary of the cross validation analysis: \n')
logger.info(st.format())