Repository : ssh://g18-sc-serv-04.diamond.ac.uk/dials_scratch On branch : master >--------------------------------------------------------------- commit ae9c94da11153b12f38a3fd4ad534cd9572e83c9 Author: James Beilsten-Edmands <[log in to unmask]> Date: Wed Jul 4 14:51:37 2018 +0100 Update cross validation script for recent scaling code updates >--------------------------------------------------------------- ae9c94da11153b12f38a3fd4ad534cd9572e83c9 command_line/scale_cross_validate.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/command_line/scale_cross_validate.py b/command_line/scale_cross_validate.py index 57c7b9d..5492886 100644 --- a/command_line/scale_cross_validate.py +++ b/command_line/scale_cross_validate.py @@ -112,7 +112,7 @@ def cross_validate(): results_dict = {} for i, v in enumerate(itertools.product(*values)): e = dict(zip(keys, v)) - results_dict[i] = {"configuration": [], "Rwork": [], "Rfree": []} + results_dict[i] = {"configuration": [], "Rwork": [], "Rfree": [], "CCwork": [], "CCfree": []} for k, v in e.iteritems(): params = set_parameter(params, k, v) results_dict[i]["configuration"].append(str(k)+'='+str(v)) @@ -130,7 +130,7 @@ def cross_validate(): k = params.cross_validation.optimise_choice params = set_parameter(params, k, value) results_dict[i] = {"configuration": [str(k)+'='+str(value)], - "Rwork": [], "Rfree": []} + "Rwork": [], "Rfree": [], "CCwork": [], "CCfree": []} for n in range(params.cross_validation.nfolds): if n < 100.0/params.scaling_options.free_set_percentage: params.scaling_options.free_set_offset = n @@ -145,7 +145,7 @@ def cross_validate(): k = params.cross_validation.optimise_parameter params = set_parameter(params, k, value) results_dict[i] = {"configuration": [str(k)+'='+str(value)], - "Rwork": [], "Rfree": []} + "Rwork": [], "Rfree": [], "CCwork": [], "CCfree": []} for n in range(params.cross_validation.nfolds): if n < 100.0/params.scaling_options.free_set_percentage: params.scaling_options.free_set_offset = n @@ -188,8 +188,10 @@ def run_script(params, experiments, reflections, results_dict): script = Script(params, experiments=deepcopy(experiments), reflections=deepcopy(reflections)) script.run(save_data=False) - results_dict["Rwork"].append(script.minimised.final_rmsds[3]) - results_dict["Rfree"].append(script.minimised.final_rmsds[4]) + results_dict["Rwork"].append(script.scaler.final_rmsds[0]) + results_dict["Rfree"].append(script.scaler.final_rmsds[1]) + results_dict["CCwork"].append(script.scaler.final_rmsds[2]) + results_dict["CCfree"].append(script.scaler.final_rmsds[3]) return results_dict def interpret_results(results_dict): @@ -197,17 +199,23 @@ def interpret_results(results_dict): Expect a configuration and final_rmsds columns. Score the data and make a nice table.""" rows = [] - headers = ['option', 'Rwork', 'Rfree'] + headers = ['option', 'Rwork', 'Rfree', 'CCwork', 'CCfree'] free_rmsds = [] + free_cc12s = [] for v in results_dict.itervalues(): config_str = ' '.join(v['configuration']) avg_work = round(sum(v['Rwork'])/len(v['Rwork']), 5) avg_free = round(sum(v['Rfree'])/len(v['Rfree']), 5) - rows.append([config_str, str(avg_work), str(avg_free)]) + avg_ccwork = round(sum(v['CCwork'])/len(v['CCwork']), 5) + avg_ccfree = round(sum(v['CCfree'])/len(v['CCfree']), 5) + rows.append([config_str, str(avg_work), str(avg_free), str(avg_ccwork), str(avg_ccfree)]) free_rmsds.append(avg_free) + free_cc12s.append(avg_ccfree) #find lowest free rmsd low_rmsd_idx = free_rmsds.index(min(free_rmsds)) + high_cc12_idx = free_cc12s.index(max(free_cc12s)) rows[low_rmsd_idx][2] += '*' + rows[high_cc12_idx][4] += '*' st = simple_table(rows, headers) logger.info('Summary of the cross validation analysis: \n') logger.info(st.format()) ######################################################################## To unsubscribe from the DIALS-COMMIT list, click the following link: https://www.jiscmail.ac.uk/cgi-bin/webadmin?SUBED1=DIALS-COMMIT&A=1